diff --git a/core/state/statedb_hooked_test.go b/core/state/statedb_hooked_test.go index 9abd76b02db8..5f82ed06d0f1 100644 --- a/core/state/statedb_hooked_test.go +++ b/core/state/statedb_hooked_test.go @@ -35,7 +35,7 @@ func TestBurn(t *testing.T) { // the following occur: // 1. contract B creates contract A // 2. contract A is destructed - // 3. constract B sends ether to A + // 3. contract B sends ether to A var burned = new(uint256.Int) s, _ := New(types.EmptyRootHash, NewDatabaseForTesting()) diff --git a/core/txpool/blobpool/blobpool.go b/core/txpool/blobpool/blobpool.go index 0352ea978394..02d339f99c2d 100644 --- a/core/txpool/blobpool/blobpool.go +++ b/core/txpool/blobpool/blobpool.go @@ -1714,3 +1714,53 @@ func (p *BlobPool) Status(hash common.Hash) txpool.TxStatus { } return txpool.TxStatusUnknown } + +// Clear implements txpool.SubPool, removing all tracked transactions +// from the blob pool and persistent store. +func (p *BlobPool) Clear() { + p.lock.Lock() + defer p.lock.Unlock() + + // manually iterating and deleting every entry is super sub-optimal + // However, Clear is not currently used in production so + // performance is not critical at the moment. + for hash := range p.lookup.txIndex { + id, _ := p.lookup.storeidOfTx(hash) + if err := p.store.Delete(id); err != nil { + log.Warn("failed to delete blob tx from backing store", "err", err) + } + } + for hash := range p.lookup.blobIndex { + id, _ := p.lookup.storeidOfBlob(hash) + if err := p.store.Delete(id); err != nil { + log.Warn("failed to delete blob from backing store", "err", err) + } + } + + // unreserve each tracked account. Ideally, we could just clear the + // reservation map in the parent txpool context. However, if we clear in + // parent context, to avoid exposing the subpool lock, we have to lock the + // reservations and then lock each subpool. + // + // This creates the potential for a deadlock situation: + // + // * TxPool.Clear locks the reservations + // * a new transaction is received which locks the subpool mutex + // * TxPool.Clear attempts to lock subpool mutex + // + // The transaction addition may attempt to reserve the sender addr which + // can't happen until Clear releases the reservation lock. Clear cannot + // acquire the subpool lock until the transaction addition is completed. + for acct, _ := range p.index { + p.reserve(acct, false) + } + p.lookup = newLookup() + p.index = make(map[common.Address][]*blobTxMeta) + p.spent = make(map[common.Address]*uint256.Int) + + var ( + basefee = uint256.MustFromBig(eip1559.CalcBaseFee(p.chain.Config(), p.head)) + blobfee = uint256.NewInt(params.BlobTxMinBlobGasprice) + ) + p.evict = newPriceHeap(basefee, blobfee, p.index) +} diff --git a/core/txpool/legacypool/legacypool.go b/core/txpool/legacypool/legacypool.go index f7495dd39f8a..89ff86df0221 100644 --- a/core/txpool/legacypool/legacypool.go +++ b/core/txpool/legacypool/legacypool.go @@ -1961,3 +1961,44 @@ func (t *lookup) RemotesBelowTip(threshold *big.Int) types.Transactions { func numSlots(tx *types.Transaction) int { return int((tx.Size() + txSlotSize - 1) / txSlotSize) } + +// Clear implements txpool.SubPool, removing all tracked txs from the pool +// and rotating the journal. +func (pool *LegacyPool) Clear() { + pool.mu.Lock() + defer pool.mu.Unlock() + + // unreserve each tracked account. Ideally, we could just clear the + // reservation map in the parent txpool context. 
However, if we clear in + // parent context, to avoid exposing the subpool lock, we have to lock the + // reservations and then lock each subpool. + // + // This creates the potential for a deadlock situation: + // + // * TxPool.Clear locks the reservations + // * a new transaction is received which locks the subpool mutex + // * TxPool.Clear attempts to lock subpool mutex + // + // The transaction addition may attempt to reserve the sender addr which + // can't happen until Clear releases the reservation lock. Clear cannot + // acquire the subpool lock until the transaction addition is completed. + for _, tx := range pool.all.remotes { + senderAddr, _ := types.Sender(pool.signer, tx) + pool.reserve(senderAddr, false) + } + for localSender, _ := range pool.locals.accounts { + pool.reserve(localSender, false) + } + + pool.all = newLookup() + pool.priced = newPricedList(pool.all) + pool.pending = make(map[common.Address]*list) + pool.queue = make(map[common.Address]*list) + + if !pool.config.NoLocals && pool.config.Journal != "" { + pool.journal = newTxJournal(pool.config.Journal) + if err := pool.journal.rotate(pool.local()); err != nil { + log.Warn("Failed to rotate transaction journal", "err", err) + } + } +} diff --git a/core/txpool/subpool.go b/core/txpool/subpool.go index 180facd217f7..9ee0a69c0be9 100644 --- a/core/txpool/subpool.go +++ b/core/txpool/subpool.go @@ -168,4 +168,7 @@ type SubPool interface { // Status returns the known status (unknown/pending/queued) of a transaction // identified by their hashes. Status(hash common.Hash) TxStatus + + // Clear removes all tracked transactions from the pool + Clear() } diff --git a/core/txpool/txpool.go b/core/txpool/txpool.go index 54ae3be56948..ce455e806e50 100644 --- a/core/txpool/txpool.go +++ b/core/txpool/txpool.go @@ -497,3 +497,10 @@ func (p *TxPool) Sync() error { return errors.New("pool already terminated") } } + +// Clear removes all tracked txs from the subpools. +func (p *TxPool) Clear() { + for _, subpool := range p.subpools { + subpool.Clear() + } +} diff --git a/core/verkle_witness_test.go b/core/verkle_witness_test.go index 5a4210cdabe3..b3088f0f901d 100644 --- a/core/verkle_witness_test.go +++ b/core/verkle_witness_test.go @@ -338,7 +338,7 @@ func TestProcessVerkleInvalidContractCreation(t *testing.T) { } } } else if bytes.Equal(stemStateDiff.Stem[:], tx1ContractStem) { - // For this contract creation, check that only the accound header and storage slot 41 + // For this contract creation, check that only the account header and storage slot 41 // are found in the witness. for _, suffixDiff := range stemStateDiff.SuffixDiffs { if suffixDiff.Suffix != 105 && suffixDiff.Suffix != 0 && suffixDiff.Suffix != 1 { diff --git a/eth/catalyst/simulated_beacon.go b/eth/catalyst/simulated_beacon.go index db46afc30d63..a24ff5210119 100644 --- a/eth/catalyst/simulated_beacon.go +++ b/eth/catalyst/simulated_beacon.go @@ -21,7 +21,6 @@ import ( "crypto/sha256" "errors" "fmt" - "math/big" "sync" "time" @@ -34,7 +33,6 @@ import ( "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/node" - "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" ) @@ -287,12 +285,7 @@ func (c *SimulatedBeacon) Commit() common.Hash { // Rollback un-sends previously added transactions. 
func (c *SimulatedBeacon) Rollback() { - // Flush all transactions from the transaction pools - maxUint256 := new(big.Int).Sub(new(big.Int).Lsh(common.Big1, 256), common.Big1) - c.eth.TxPool().SetGasTip(maxUint256) - // Set the gas tip back to accept new transactions - // TODO (Marius van der Wijden): set gas tip to parameter passed by config - c.eth.TxPool().SetGasTip(big.NewInt(params.GWei)) + c.eth.TxPool().Clear() } // Fork sets the head to the provided hash. diff --git a/go.mod b/go.mod index 253c55c7f97a..1928f7b5c5dc 100644 --- a/go.mod +++ b/go.mod @@ -47,7 +47,6 @@ require ( github.com/jackpal/go-nat-pmp v1.0.2 github.com/jedisct1/go-minisign v0.0.0-20230811132847-661be99b8267 github.com/karalabe/hid v1.0.1-0.20240306101548-573246063e52 - github.com/kilic/bls12-381 v0.1.0 github.com/kylelemons/godebug v1.1.0 github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.20 @@ -123,6 +122,7 @@ require ( github.com/hashicorp/go-retryablehttp v0.7.4 // indirect github.com/influxdata/line-protocol v0.0.0-20200327222509-2487e7298839 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/kilic/bls12-381 v0.1.0 // indirect github.com/klauspost/compress v1.16.0 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/kr/pretty v0.3.1 // indirect diff --git a/oss-fuzz.sh b/oss-fuzz.sh index 50491b915561..5e4aa1c25336 100644 --- a/oss-fuzz.sh +++ b/oss-fuzz.sh @@ -160,6 +160,10 @@ compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ FuzzG1Add fuzz_g1_add\ $repo/tests/fuzzers/bls12381/bls12381_test.go +compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ + FuzzCrossG1Mul fuzz_cross_g1_mul\ + $repo/tests/fuzzers/bls12381/bls12381_test.go + compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ FuzzG1Mul fuzz_g1_mul\ $repo/tests/fuzzers/bls12381/bls12381_test.go @@ -172,6 +176,10 @@ compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ FuzzG2Add fuzz_g2_add \ $repo/tests/fuzzers/bls12381/bls12381_test.go +compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ + FuzzCrossG2Mul fuzz_cross_g2_mul\ + $repo/tests/fuzzers/bls12381/bls12381_test.go + compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ FuzzG2Mul fuzz_g2_mul\ $repo/tests/fuzzers/bls12381/bls12381_test.go @@ -204,6 +212,10 @@ compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ FuzzCrossG2Add fuzz_cross_g2_add \ $repo/tests/fuzzers/bls12381/bls12381_test.go +compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ + FuzzCrossG2MultiExp fuzz_cross_g2_multiexp \ + $repo/tests/fuzzers/bls12381/bls12381_test.go + compile_fuzzer github.com/ethereum/go-ethereum/tests/fuzzers/bls12381 \ FuzzCrossPairing fuzz_cross_pairing\ $repo/tests/fuzzers/bls12381/bls12381_test.go diff --git a/rpc/client_test.go b/rpc/client_test.go index 49f2350b404d..6c1a4f8f6c00 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -38,6 +38,8 @@ import ( ) func TestClientRequest(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() client := DialInProc(server) @@ -53,6 +55,8 @@ func TestClientRequest(t *testing.T) { } func TestClientResponseType(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() client := DialInProc(server) @@ -71,6 +75,8 @@ func TestClientResponseType(t *testing.T) { // This test checks calling a method that returns 'null'. 
func TestClientNullResponse(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() @@ -91,6 +97,8 @@ func TestClientNullResponse(t *testing.T) { // This test checks that server-returned errors with code and data come out of Client.Call. func TestClientErrorData(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() client := DialInProc(server) @@ -121,6 +129,8 @@ func TestClientErrorData(t *testing.T) { } func TestClientBatchRequest(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() client := DialInProc(server) @@ -172,6 +182,8 @@ func TestClientBatchRequest(t *testing.T) { // This checks that, for HTTP connections, the length of batch responses is validated to // match the request exactly. func TestClientBatchRequest_len(t *testing.T) { + t.Parallel() + b, err := json.Marshal([]jsonrpcMessage{ {Version: "2.0", ID: json.RawMessage("1"), Result: json.RawMessage(`"0x1"`)}, {Version: "2.0", ID: json.RawMessage("2"), Result: json.RawMessage(`"0x2"`)}, @@ -188,6 +200,8 @@ func TestClientBatchRequest_len(t *testing.T) { t.Cleanup(s.Close) t.Run("too-few", func(t *testing.T) { + t.Parallel() + client, err := Dial(s.URL) if err != nil { t.Fatal("failed to dial test server:", err) @@ -218,6 +232,8 @@ func TestClientBatchRequest_len(t *testing.T) { }) t.Run("too-many", func(t *testing.T) { + t.Parallel() + client, err := Dial(s.URL) if err != nil { t.Fatal("failed to dial test server:", err) @@ -249,6 +265,8 @@ func TestClientBatchRequest_len(t *testing.T) { // This checks that the client can handle the case where the server doesn't // respond to all requests in a batch. func TestClientBatchRequestLimit(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() server.SetBatchLimits(2, 100000) @@ -285,6 +303,8 @@ func TestClientBatchRequestLimit(t *testing.T) { } func TestClientNotify(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() client := DialInProc(server) @@ -392,6 +412,8 @@ func testClientCancel(transport string, t *testing.T) { } func TestClientSubscribeInvalidArg(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() client := DialInProc(server) @@ -422,6 +444,8 @@ func TestClientSubscribeInvalidArg(t *testing.T) { } func TestClientSubscribe(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() client := DialInProc(server) @@ -454,6 +478,8 @@ func TestClientSubscribe(t *testing.T) { // In this test, the connection drops while Subscribe is waiting for a response. func TestClientSubscribeClose(t *testing.T) { + t.Parallel() + server := newTestServer() service := ¬ificationTestService{ gotHangSubscriptionReq: make(chan struct{}), @@ -498,6 +524,8 @@ func TestClientSubscribeClose(t *testing.T) { // This test reproduces https://github.com/ethereum/go-ethereum/issues/17837 where the // client hangs during shutdown when Unsubscribe races with Client.Close. func TestClientCloseUnsubscribeRace(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() @@ -540,6 +568,8 @@ func (b *unsubscribeBlocker) readBatch() ([]*jsonrpcMessage, bool, error) { // not respond. 
// It reproducers the issue https://github.com/ethereum/go-ethereum/issues/30156 func TestUnsubscribeTimeout(t *testing.T) { + t.Parallel() + srv := NewServer() srv.RegisterName("nftest", new(notificationTestService)) @@ -674,6 +704,8 @@ func TestClientSubscriptionChannelClose(t *testing.T) { // This test checks that Client doesn't lock up when a single subscriber // doesn't read subscription events. func TestClientNotificationStorm(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() @@ -726,6 +758,8 @@ func TestClientNotificationStorm(t *testing.T) { } func TestClientSetHeader(t *testing.T) { + t.Parallel() + var gotHeader bool srv := newTestServer() httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -762,6 +796,8 @@ func TestClientSetHeader(t *testing.T) { } func TestClientHTTP(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() @@ -804,6 +840,8 @@ func TestClientHTTP(t *testing.T) { } func TestClientReconnect(t *testing.T) { + t.Parallel() + startServer := func(addr string) (*Server, net.Listener) { srv := newTestServer() l, err := net.Listen("tcp", addr) diff --git a/rpc/http_test.go b/rpc/http_test.go index ad86ca15aebd..6c268b62928d 100644 --- a/rpc/http_test.go +++ b/rpc/http_test.go @@ -58,24 +58,34 @@ func confirmRequestValidationCode(t *testing.T, method, contentType, body string } func TestHTTPErrorResponseWithDelete(t *testing.T) { + t.Parallel() + confirmRequestValidationCode(t, http.MethodDelete, contentType, "", http.StatusMethodNotAllowed) } func TestHTTPErrorResponseWithPut(t *testing.T) { + t.Parallel() + confirmRequestValidationCode(t, http.MethodPut, contentType, "", http.StatusMethodNotAllowed) } func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) { + t.Parallel() + body := make([]rune, defaultBodyLimit+1) confirmRequestValidationCode(t, http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge) } func TestHTTPErrorResponseWithEmptyContentType(t *testing.T) { + t.Parallel() + confirmRequestValidationCode(t, http.MethodPost, "", "", http.StatusUnsupportedMediaType) } func TestHTTPErrorResponseWithValidRequest(t *testing.T) { + t.Parallel() + confirmRequestValidationCode(t, http.MethodPost, contentType, "", 0) } @@ -101,11 +111,15 @@ func confirmHTTPRequestYieldsStatusCode(t *testing.T, method, contentType, body } func TestHTTPResponseWithEmptyGet(t *testing.T) { + t.Parallel() + confirmHTTPRequestYieldsStatusCode(t, http.MethodGet, "", "", http.StatusOK) } // This checks that maxRequestContentLength is not applied to the response of a request. func TestHTTPRespBodyUnlimited(t *testing.T) { + t.Parallel() + const respLength = defaultBodyLimit * 3 s := NewServer() @@ -132,6 +146,8 @@ func TestHTTPRespBodyUnlimited(t *testing.T) { // Tests that an HTTP error results in an HTTPError instance // being returned with the expected attributes. 
func TestHTTPErrorResponse(t *testing.T) { + t.Parallel() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "error has occurred!", http.StatusTeapot) })) @@ -169,6 +185,8 @@ func TestHTTPErrorResponse(t *testing.T) { } func TestHTTPPeerInfo(t *testing.T) { + t.Parallel() + s := newTestServer() defer s.Stop() ts := httptest.NewServer(s) @@ -205,6 +223,8 @@ func TestHTTPPeerInfo(t *testing.T) { } func TestNewContextWithHeaders(t *testing.T) { + t.Parallel() + expectedHeaders := 0 server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { for i := 0; i < expectedHeaders; i++ { diff --git a/rpc/server_test.go b/rpc/server_test.go index 9d1c7fb5f0fe..9ee545d81ade 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -29,6 +29,8 @@ import ( ) func TestServerRegisterName(t *testing.T) { + t.Parallel() + server := NewServer() service := new(testService) @@ -53,6 +55,8 @@ func TestServerRegisterName(t *testing.T) { } func TestServer(t *testing.T) { + t.Parallel() + files, err := os.ReadDir("testdata") if err != nil { t.Fatal("where'd my testdata go?") @@ -64,6 +68,8 @@ func TestServer(t *testing.T) { path := filepath.Join("testdata", f.Name()) name := strings.TrimSuffix(f.Name(), filepath.Ext(f.Name())) t.Run(name, func(t *testing.T) { + t.Parallel() + runTestScript(t, path) }) } @@ -116,6 +122,8 @@ func runTestScript(t *testing.T, file string) { // This test checks that responses are delivered for very short-lived connections that // only carry a single request. func TestServerShortLivedConn(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() @@ -156,6 +164,8 @@ func TestServerShortLivedConn(t *testing.T) { } func TestServerBatchResponseSizeLimit(t *testing.T) { + t.Parallel() + server := newTestServer() defer server.Stop() server.SetBatchLimits(100, 60) diff --git a/rpc/subscription_test.go b/rpc/subscription_test.go index ab40ab169ff6..e52f390adb94 100644 --- a/rpc/subscription_test.go +++ b/rpc/subscription_test.go @@ -33,6 +33,8 @@ import ( ) func TestNewID(t *testing.T) { + t.Parallel() + hexchars := "0123456789ABCDEFabcdef" for i := 0; i < 100; i++ { id := string(NewID()) @@ -54,6 +56,8 @@ func TestNewID(t *testing.T) { } func TestSubscriptions(t *testing.T) { + t.Parallel() + var ( namespaces = []string{"eth", "bzz"} service = ¬ificationTestService{} @@ -132,6 +136,8 @@ func TestSubscriptions(t *testing.T) { // This test checks that unsubscribing works. 
func TestServerUnsubscribe(t *testing.T) { + t.Parallel() + p1, p2 := net.Pipe() defer p2.Close() @@ -260,6 +266,8 @@ func BenchmarkNotify(b *testing.B) { } func TestNotify(t *testing.T) { + t.Parallel() + out := new(bytes.Buffer) id := ID("test") notifier := &Notifier{ diff --git a/rpc/types_test.go b/rpc/types_test.go index aba40b5863f4..9dd6fa650807 100644 --- a/rpc/types_test.go +++ b/rpc/types_test.go @@ -26,6 +26,8 @@ import ( ) func TestBlockNumberJSONUnmarshal(t *testing.T) { + t.Parallel() + tests := []struct { input string mustFail bool @@ -70,6 +72,8 @@ func TestBlockNumberJSONUnmarshal(t *testing.T) { } func TestBlockNumberOrHash_UnmarshalJSON(t *testing.T) { + t.Parallel() + tests := []struct { input string mustFail bool @@ -131,6 +135,8 @@ func TestBlockNumberOrHash_UnmarshalJSON(t *testing.T) { } func TestBlockNumberOrHash_WithNumber_MarshalAndUnmarshal(t *testing.T) { + t.Parallel() + tests := []struct { name string number int64 @@ -144,6 +150,8 @@ func TestBlockNumberOrHash_WithNumber_MarshalAndUnmarshal(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + t.Parallel() + bnh := BlockNumberOrHashWithNumber(BlockNumber(test.number)) marshalled, err := json.Marshal(bnh) if err != nil { @@ -162,6 +170,8 @@ func TestBlockNumberOrHash_WithNumber_MarshalAndUnmarshal(t *testing.T) { } func TestBlockNumberOrHash_StringAndUnmarshal(t *testing.T) { + t.Parallel() + tests := []BlockNumberOrHash{ BlockNumberOrHashWithNumber(math.MaxInt64), BlockNumberOrHashWithNumber(PendingBlockNumber), diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index c6ea325d2926..10a998b3512c 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -174,6 +174,8 @@ func TestWebsocketLargeRead(t *testing.T) { } func TestWebsocketPeerInfo(t *testing.T) { + t.Parallel() + var ( s = newTestServer() ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"})) @@ -259,6 +261,8 @@ func TestClientWebsocketPing(t *testing.T) { // This checks that the websocket transport can deal with large messages. 
func TestClientWebsocketLargeMessage(t *testing.T) { + t.Parallel() + var ( srv = NewServer() httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) diff --git a/tests/fuzzers/bls12381/bls12381_fuzz.go b/tests/fuzzers/bls12381/bls12381_fuzz.go index 74ea6f52a75e..a3e0e9f72b0a 100644 --- a/tests/fuzzers/bls12381/bls12381_fuzz.go +++ b/tests/fuzzers/bls12381/bls12381_fuzz.go @@ -31,42 +31,33 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/ethereum/go-ethereum/common" - bls12381 "github.com/kilic/bls12-381" blst "github.com/supranational/blst/bindings/go" ) func fuzzG1SubgroupChecks(data []byte) int { input := bytes.NewReader(data) - kpG1, cpG1, blG1, err := getG1Points(input) + cpG1, blG1, err := getG1Points(input) if err != nil { return 0 } - inSubGroupKilic := bls12381.NewG1().InCorrectSubgroup(kpG1) inSubGroupGnark := cpG1.IsInSubGroup() inSubGroupBLST := blG1.InG1() - if inSubGroupKilic != inSubGroupGnark { - panic(fmt.Sprintf("differing subgroup check, kilic %v, gnark %v", inSubGroupKilic, inSubGroupGnark)) - } - if inSubGroupKilic != inSubGroupBLST { - panic(fmt.Sprintf("differing subgroup check, kilic %v, blst %v", inSubGroupKilic, inSubGroupBLST)) + if inSubGroupGnark != inSubGroupBLST { + panic(fmt.Sprintf("differing subgroup check, gnark %v, blst %v", inSubGroupGnark, inSubGroupBLST)) } return 1 } func fuzzG2SubgroupChecks(data []byte) int { input := bytes.NewReader(data) - kpG2, cpG2, blG2, err := getG2Points(input) + gpG2, blG2, err := getG2Points(input) if err != nil { return 0 } - inSubGroupKilic := bls12381.NewG2().InCorrectSubgroup(kpG2) - inSubGroupGnark := cpG2.IsInSubGroup() + inSubGroupGnark := gpG2.IsInSubGroup() inSubGroupBLST := blG2.InG2() - if inSubGroupKilic != inSubGroupGnark { - panic(fmt.Sprintf("differing subgroup check, kilic %v, gnark %v", inSubGroupKilic, inSubGroupGnark)) - } - if inSubGroupKilic != inSubGroupBLST { - panic(fmt.Sprintf("differing subgroup check, kilic %v, blst %v", inSubGroupKilic, inSubGroupBLST)) + if inSubGroupGnark != inSubGroupBLST { + panic(fmt.Sprintf("differing subgroup check, gnark %v, blst %v", inSubGroupGnark, inSubGroupBLST)) } return 1 } @@ -75,38 +66,28 @@ func fuzzCrossPairing(data []byte) int { input := bytes.NewReader(data) // get random G1 points - kpG1, cpG1, blG1, err := getG1Points(input) + cpG1, blG1, err := getG1Points(input) if err != nil { return 0 } // get random G2 points - kpG2, cpG2, blG2, err := getG2Points(input) + cpG2, blG2, err := getG2Points(input) if err != nil { return 0 } - // compute pairing using geth - engine := bls12381.NewEngine() - engine.AddPair(kpG1, kpG2) - kResult := engine.Result() - // compute pairing using gnark cResult, err := gnark.Pair([]gnark.G1Affine{*cpG1}, []gnark.G2Affine{*cpG2}) if err != nil { panic(fmt.Sprintf("gnark/bls12381 encountered error: %v", err)) } - // compare result - if !(bytes.Equal(cResult.Marshal(), bls12381.NewGT().ToBytes(kResult))) { - panic("pairing mismatch gnark / geth ") - } - // compute pairing using blst blstResult := blst.Fp12MillerLoop(blG2, blG1) blstResult.FinalExp() res := massageBLST(blstResult.ToBendian()) - if !(bytes.Equal(res, bls12381.NewGT().ToBytes(kResult))) { + if !(bytes.Equal(res, cResult.Marshal())) { panic("pairing mismatch blst / geth") } @@ -141,32 +122,22 @@ func fuzzCrossG1Add(data []byte) int { input := bytes.NewReader(data) // get random G1 points - kp1, cp1, bl1, err := getG1Points(input) + cp1, bl1, err := getG1Points(input) if err != nil { return 0 
} // get random G1 points - kp2, cp2, bl2, err := getG1Points(input) + cp2, bl2, err := getG1Points(input) if err != nil { return 0 } - // compute kp = kp1 + kp2 - g1 := bls12381.NewG1() - kp := bls12381.PointG1{} - g1.Add(&kp, kp1, kp2) - // compute cp = cp1 + cp2 _cp1 := new(gnark.G1Jac).FromAffine(cp1) _cp2 := new(gnark.G1Jac).FromAffine(cp2) cp := new(gnark.G1Affine).FromJacobian(_cp1.AddAssign(_cp2)) - // compare result - if !(bytes.Equal(cp.Marshal(), g1.ToBytes(&kp))) { - panic("G1 point addition mismatch gnark / geth ") - } - bl3 := blst.P1AffinesAdd([]*blst.P1Affine{bl1, bl2}) if !(bytes.Equal(cp.Marshal(), bl3.Serialize())) { panic("G1 point addition mismatch blst / geth ") @@ -179,34 +150,24 @@ func fuzzCrossG2Add(data []byte) int { input := bytes.NewReader(data) // get random G2 points - kp1, cp1, bl1, err := getG2Points(input) + gp1, bl1, err := getG2Points(input) if err != nil { return 0 } // get random G2 points - kp2, cp2, bl2, err := getG2Points(input) + gp2, bl2, err := getG2Points(input) if err != nil { return 0 } - // compute kp = kp1 + kp2 - g2 := bls12381.NewG2() - kp := bls12381.PointG2{} - g2.Add(&kp, kp1, kp2) - // compute cp = cp1 + cp2 - _cp1 := new(gnark.G2Jac).FromAffine(cp1) - _cp2 := new(gnark.G2Jac).FromAffine(cp2) - cp := new(gnark.G2Affine).FromJacobian(_cp1.AddAssign(_cp2)) - - // compare result - if !(bytes.Equal(cp.Marshal(), g2.ToBytes(&kp))) { - panic("G2 point addition mismatch gnark / geth ") - } + _gp1 := new(gnark.G2Jac).FromAffine(gp1) + _gp2 := new(gnark.G2Jac).FromAffine(gp2) + gp := new(gnark.G2Affine).FromJacobian(_gp1.AddAssign(_gp2)) bl3 := blst.P2AffinesAdd([]*blst.P2Affine{bl1, bl2}) - if !(bytes.Equal(cp.Marshal(), bl3.Serialize())) { + if !(bytes.Equal(gp.Marshal(), bl3.Serialize())) { panic("G1 point addition mismatch blst / geth ") } @@ -216,10 +177,10 @@ func fuzzCrossG2Add(data []byte) int { func fuzzCrossG1MultiExp(data []byte) int { var ( input = bytes.NewReader(data) - gethScalars []*bls12381.Fr gnarkScalars []fr.Element - gethPoints []*bls12381.PointG1 gnarkPoints []gnark.G1Affine + blstScalars []*blst.Scalar + blstPoints []*blst.P1Affine ) // n random scalars (max 17) for i := 0; i < 17; i++ { @@ -229,50 +190,147 @@ func fuzzCrossG1MultiExp(data []byte) int { break } // get a random G1 point as basis - kp1, cp1, _, err := getG1Points(input) + cp1, bl1, err := getG1Points(input) if err != nil { break } - gethScalars = append(gethScalars, bls12381.NewFr().FromBytes(s.Bytes())) - var gnarkScalar = &fr.Element{} - gnarkScalar = gnarkScalar.SetBigInt(s) - gnarkScalars = append(gnarkScalars, *gnarkScalar) - gethPoints = append(gethPoints, new(bls12381.PointG1).Set(kp1)) + gnarkScalar := new(fr.Element).SetBigInt(s) + gnarkScalars = append(gnarkScalars, *gnarkScalar) gnarkPoints = append(gnarkPoints, *cp1) + + blstScalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32)) + blstScalars = append(blstScalars, blstScalar) + blstPoints = append(blstPoints, bl1) } - if len(gethScalars) == 0 { + + if len(gnarkScalars) == 0 || len(gnarkScalars) != len(gnarkPoints) { return 0 } - // compute multi exponentiation - g1 := bls12381.NewG1() - kp := bls12381.PointG1{} - if _, err := g1.MultiExp(&kp, gethPoints, gethScalars); err != nil { - panic(fmt.Sprintf("G1 multi exponentiation errored (geth): %v", err)) - } - // note that geth/crypto/bls12381.MultiExp mutates the scalars slice (and sets all the scalars to zero) // gnark multi exp cp := new(gnark.G1Affine) cp.MultiExp(gnarkPoints, gnarkScalars, ecc.MultiExpConfig{}) - // compare 
result - gnarkRes := cp.Marshal() - gethRes := g1.ToBytes(&kp) - if !bytes.Equal(gnarkRes, gethRes) { - msg := fmt.Sprintf("G1 multi exponentiation mismatch gnark/geth.\ngnark: %x\ngeth: %x\ninput: %x\n ", - gnarkRes, gethRes, data) - panic(msg) + expectedGnark := multiExpG1Gnark(gnarkPoints, gnarkScalars) + if !bytes.Equal(cp.Marshal(), expectedGnark.Marshal()) { + panic("g1 multi exponentiation mismatch") + } + + // blst multi exp + expectedBlst := blst.P1AffinesMult(blstPoints, blstScalars, 256).ToAffine() + if !bytes.Equal(cp.Marshal(), expectedBlst.Serialize()) { + panic("g1 multi exponentiation mismatch, gnark/blst") + } + return 1 +} + +func fuzzCrossG1Mul(data []byte) int { + input := bytes.NewReader(data) + gp, blpAffine, err := getG1Points(input) + if err != nil { + return 0 + } + scalar, err := randomScalar(input, fp.Modulus()) + if err != nil { + return 0 + } + + blScalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(scalar.Bytes(), 32)) + + blp := new(blst.P1) + blp.FromAffine(blpAffine) + + resBl := blp.Mult(blScalar) + resGeth := (new(gnark.G1Affine)).ScalarMultiplication(gp, scalar) + + if !bytes.Equal(resGeth.Marshal(), resBl.Serialize()) { + panic("bytes(blst.G1) != bytes(geth.G1)") + } + return 1 +} + +func fuzzCrossG2Mul(data []byte) int { + input := bytes.NewReader(data) + gp, blpAffine, err := getG2Points(input) + if err != nil { + return 0 + } + scalar, err := randomScalar(input, fp.Modulus()) + if err != nil { + return 0 + } + + blScalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(scalar.Bytes(), 32)) + + blp := new(blst.P2) + blp.FromAffine(blpAffine) + + resBl := blp.Mult(blScalar) + resGeth := (new(gnark.G2Affine)).ScalarMultiplication(gp, scalar) + + if !bytes.Equal(resGeth.Marshal(), resBl.Serialize()) { + panic("bytes(blst.G1) != bytes(geth.G1)") + } + return 1 +} + +func fuzzCrossG2MultiExp(data []byte) int { + var ( + input = bytes.NewReader(data) + gnarkScalars []fr.Element + gnarkPoints []gnark.G2Affine + blstScalars []*blst.Scalar + blstPoints []*blst.P2Affine + ) + // n random scalars (max 17) + for i := 0; i < 17; i++ { + // note that geth/crypto/bls12381 works only with scalars <= 32bytes + s, err := randomScalar(input, fr.Modulus()) + if err != nil { + break + } + // get a random G1 point as basis + cp1, bl1, err := getG2Points(input) + if err != nil { + break + } + + gnarkScalar := new(fr.Element).SetBigInt(s) + gnarkScalars = append(gnarkScalars, *gnarkScalar) + gnarkPoints = append(gnarkPoints, *cp1) + + blstScalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32)) + blstScalars = append(blstScalars, blstScalar) + blstPoints = append(blstPoints, bl1) + } + + if len(gnarkScalars) == 0 || len(gnarkScalars) != len(gnarkPoints) { + return 0 + } + + // gnark multi exp + cp := new(gnark.G2Affine) + cp.MultiExp(gnarkPoints, gnarkScalars, ecc.MultiExpConfig{}) + + expectedGnark := multiExpG2Gnark(gnarkPoints, gnarkScalars) + if !bytes.Equal(cp.Marshal(), expectedGnark.Marshal()) { + panic("g1 multi exponentiation mismatch") } + // blst multi exp + expectedBlst := blst.P2AffinesMult(blstPoints, blstScalars, 256).ToAffine() + if !bytes.Equal(cp.Marshal(), expectedBlst.Serialize()) { + panic("g1 multi exponentiation mismatch, gnark/blst") + } return 1 } -func getG1Points(input io.Reader) (*bls12381.PointG1, *gnark.G1Affine, *blst.P1Affine, error) { +func getG1Points(input io.Reader) (*gnark.G1Affine, *blst.P1Affine, error) { // sample a random scalar s, err := randomScalar(input, fp.Modulus()) if err != nil { - return nil, nil, 
nil, err + return nil, nil, err } // compute a random point @@ -281,18 +339,6 @@ func getG1Points(input io.Reader) (*bls12381.PointG1, *gnark.G1Affine, *blst.P1A cp.ScalarMultiplication(&g1Gen, s) cpBytes := cp.Marshal() - // marshal gnark point -> geth point - g1 := bls12381.NewG1() - kp, err := g1.FromBytes(cpBytes) - if err != nil { - panic(fmt.Sprintf("Could not marshal gnark.G1 -> geth.G1: %v", err)) - } - - gnarkRes := g1.ToBytes(kp) - if !bytes.Equal(gnarkRes, cpBytes) { - panic(fmt.Sprintf("bytes(gnark.G1) != bytes(geth.G1)\ngnark.G1: %x\ngeth.G1: %x\n", gnarkRes, cpBytes)) - } - // marshal gnark point -> blst point scalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32)) p1 := new(blst.P1Affine).From(scalar) @@ -301,43 +347,31 @@ func getG1Points(input io.Reader) (*bls12381.PointG1, *gnark.G1Affine, *blst.P1A panic(fmt.Sprintf("bytes(blst.G1) != bytes(geth.G1)\nblst.G1: %x\ngeth.G1: %x\n", blstRes, cpBytes)) } - return kp, cp, p1, nil + return cp, p1, nil } -func getG2Points(input io.Reader) (*bls12381.PointG2, *gnark.G2Affine, *blst.P2Affine, error) { +func getG2Points(input io.Reader) (*gnark.G2Affine, *blst.P2Affine, error) { // sample a random scalar s, err := randomScalar(input, fp.Modulus()) if err != nil { - return nil, nil, nil, err + return nil, nil, err } // compute a random point - cp := new(gnark.G2Affine) + gp := new(gnark.G2Affine) _, _, _, g2Gen := gnark.Generators() - cp.ScalarMultiplication(&g2Gen, s) - cpBytes := cp.Marshal() - - // marshal gnark point -> geth point - g2 := bls12381.NewG2() - kp, err := g2.FromBytes(cpBytes) - if err != nil { - panic(fmt.Sprintf("Could not marshal gnark.G2 -> geth.G2: %v", err)) - } - - gnarkRes := g2.ToBytes(kp) - if !bytes.Equal(gnarkRes, cpBytes) { - panic(fmt.Sprintf("bytes(gnark.G2) != bytes(geth.G2)\ngnark.G2: %x\ngeth.G2: %x\n", gnarkRes, cpBytes)) - } + gp.ScalarMultiplication(&g2Gen, s) + cpBytes := gp.Marshal() // marshal gnark point -> blst point // Left pad the scalar to 32 bytes scalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32)) p2 := new(blst.P2Affine).From(scalar) if !bytes.Equal(p2.Serialize(), cpBytes) { - panic("bytes(blst.G2) != bytes(geth.G2)") + panic("bytes(blst.G2) != bytes(bls12381.G2)") } - return kp, cp, p2, nil + return gp, p2, nil } func randomScalar(r io.Reader, max *big.Int) (k *big.Int, err error) { @@ -348,3 +382,29 @@ func randomScalar(r io.Reader, max *big.Int) (k *big.Int, err error) { } } } + +// multiExpG1Gnark is a naive implementation of G1 multi-exponentiation +func multiExpG1Gnark(gs []gnark.G1Affine, scalars []fr.Element) gnark.G1Affine { + res := gnark.G1Affine{} + for i := 0; i < len(gs); i++ { + tmp := new(gnark.G1Affine) + sb := scalars[i].Bytes() + scalarBytes := new(big.Int).SetBytes(sb[:]) + tmp.ScalarMultiplication(&gs[i], scalarBytes) + res.Add(&res, tmp) + } + return res +} + +// multiExpG1Gnark is a naive implementation of G1 multi-exponentiation +func multiExpG2Gnark(gs []gnark.G2Affine, scalars []fr.Element) gnark.G2Affine { + res := gnark.G2Affine{} + for i := 0; i < len(gs); i++ { + tmp := new(gnark.G2Affine) + sb := scalars[i].Bytes() + scalarBytes := new(big.Int).SetBytes(sb[:]) + tmp.ScalarMultiplication(&gs[i], scalarBytes) + res.Add(&res, tmp) + } + return res +} diff --git a/tests/fuzzers/bls12381/bls12381_test.go b/tests/fuzzers/bls12381/bls12381_test.go index fd782f7813f0..d4e5e20e04f7 100644 --- a/tests/fuzzers/bls12381/bls12381_test.go +++ b/tests/fuzzers/bls12381/bls12381_test.go @@ -27,6 +27,12 @@ func FuzzCrossPairing(f 
*testing.F) { }) } +func FuzzCrossG2MultiExp(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + fuzzCrossG2MultiExp(data) + }) +} + func FuzzCrossG1Add(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { fuzzCrossG1Add(data) @@ -51,9 +57,9 @@ func FuzzG1Add(f *testing.F) { }) } -func FuzzG1Mul(f *testing.F) { +func FuzzCrossG1Mul(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { - fuzz(blsG1Mul, data) + fuzzCrossG1Mul(data) }) } @@ -69,9 +75,9 @@ func FuzzG2Add(f *testing.F) { }) } -func FuzzG2Mul(f *testing.F) { +func FuzzCrossG2Mul(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { - fuzz(blsG2Mul, data) + fuzzCrossG2Mul(data) }) } @@ -110,3 +116,15 @@ func FuzzG2SubgroupChecks(f *testing.F) { fuzzG2SubgroupChecks(data) }) } + +func FuzzG2Mul(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + fuzz(blsG2Mul, data) + }) +} + +func FuzzG1Mul(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + fuzz(blsG1Mul, data) + }) +} diff --git a/version/version.go b/version/version.go index cbd59f3e9ad9..52dfa0281e54 100644 --- a/version/version.go +++ b/version/version.go @@ -19,6 +19,6 @@ package version const ( Major = 1 // Major version component of the current release Minor = 14 // Minor version component of the current release - Patch = 12 // Patch version component of the current release + Patch = 13 // Patch version component of the current release Meta = "unstable" // Version metadata to append to the version string )
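The Clear additions above can be exercised end to end through the public TxPool wrapper. The sketch below is illustrative only and not part of the diff: it assumes a pool built with the existing legacypool/blobpool test fixtures and that TxPool.Stats still reports (pending, queued) counts; checkClear is a name introduced here for the example.

package txpool_test

import (
	"testing"

	"github.com/ethereum/go-ethereum/core/txpool"
)

// checkClear asserts that Clear leaves the pool empty. Clear also releases the
// per-address reservations held by the subpools, so the same senders can be
// reused immediately afterwards (which is what SimulatedBeacon.Rollback now
// relies on instead of the old SetGasTip workaround).
func checkClear(t *testing.T, pool *txpool.TxPool) {
	t.Helper()

	pool.Clear()

	if pending, queued := pool.Stats(); pending != 0 || queued != 0 {
		t.Fatalf("pool not empty after Clear: pending=%d queued=%d", pending, queued)
	}
}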
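Besides the oss-fuzz.sh entries, the new cross-implementation harnesses also run under native Go fuzzing, e.g. go test -fuzz='^FuzzCrossG1Mul$' ./tests/fuzzers/bls12381. A seeded wrapper along the following lines could sit next to the existing targets; it is a sketch only — the seed bytes and the FuzzCrossG1MulSeeded name are assumptions, not part of the diff.

package bls12381

import (
	"bytes"
	"testing"
)

// FuzzCrossG1MulSeeded feeds a fixed input into the gnark/blst G1
// scalar-multiplication cross-check. Inputs too short to draw the required
// scalars simply make fuzzCrossG1Mul return early without comparing anything.
func FuzzCrossG1MulSeeded(f *testing.F) {
	f.Add(bytes.Repeat([]byte{0x2a}, 128)) // arbitrary seed corpus entry
	f.Fuzz(func(t *testing.T, data []byte) {
		fuzzCrossG1Mul(data)
	})
}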