Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Performance Issue #31

Merged
merged 1 commit into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ test test_race test_with_mock:
go test -v $(COVER_OPTS) -coverprofile=$(COVER_OUT) && go tool cover -html=$(COVER_OUT) -o $(COVER_HTML) && go tool cover -func=$(COVER_OUT) -o $(COVER_OUT)

test_fuzz:
go test -v -race -fuzz=FuzzMerkleTreeNew -fuzztime=30m -run ^FuzzMerkleTreeNew$
go test -v -race -fuzz=FuzzMerkleTreeNew -fuzztime=60m -run ^FuzzMerkleTreeNew$

test_ci_coverage:
go test -race -gcflags=all=-l -coverprofile=coverage.txt -covermode=atomic
Expand Down
50 changes: 18 additions & 32 deletions proof_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (m *MerkleTree) generateProofs() (err error) {
buffer, bufferSize := initBuffer(m.Leaves)
for step := 0; step < m.Depth; step++ {
bufferSize = fixOddNumOfNodes(buffer, bufferSize, step)
updateProofs(m.Proofs, buffer, bufferSize, step)
m.updateProofs(buffer, bufferSize, step)
for idx := 0; idx < bufferSize; idx += 2 {
leftIdx := idx << step
rightIdx := min(leftIdx+(1<<step), len(buffer)-1)
Expand All @@ -62,19 +62,21 @@ func (m *MerkleTree) generateProofsParallel() (err error) {
// Limit the number of workers to the previous level length.
numRoutines = min(numRoutines, bufferSize)
bufferSize = fixOddNumOfNodes(buffer, bufferSize, step)
updateProofsParallel(m.Proofs, buffer, bufferSize, step, m.NumLeaves)
var (
eg = new(errgroup.Group)
hashFunc = m.HashFunc
concatHashFunc = m.concatHashFunc
)
m.updateProofsParallel(buffer, bufferSize, step)
eg := new(errgroup.Group)
for startIdx := 0; startIdx < numRoutines; startIdx++ {
startIdx := startIdx << 1
eg.Go(func() error {
return workerProofGen(
hashFunc, concatHashFunc,
buffer, bufferSize, numRoutines, startIdx, step,
)
var err error
for i := startIdx; i < bufferSize; i += numRoutines << 1 {
leftIdx := i << step
rightIdx := min(leftIdx+(1<<step), len(buffer)-1)
buffer[leftIdx], err = m.HashFunc(m.concatHashFunc(buffer[leftIdx], buffer[rightIdx]))
if err != nil {
return err
}
}
return nil
})
}
if err = eg.Wait(); err != nil {
Expand All @@ -86,22 +88,6 @@ func (m *MerkleTree) generateProofsParallel() (err error) {
return
}

func workerProofGen(
hashFunc TypeHashFunc, concatHashFunc typeConcatHashFunc,
buffer [][]byte, bufferSize, numRoutine, startIdx, step int,
) error {
var err error
for i := startIdx; i < bufferSize; i += numRoutine << 1 {
leftIdx := i << step
rightIdx := min(leftIdx+(1<<step), len(buffer)-1)
buffer[leftIdx], err = hashFunc(concatHashFunc(buffer[leftIdx], buffer[rightIdx]))
if err != nil {
return err
}
}
return nil
}

// initProofs initializes the MerkleTree's Proofs with the appropriate size and depth.
// This is to reduce overhead of slice resizing during the generation process.
func (m *MerkleTree) initProofs() {
Expand Down Expand Up @@ -145,26 +131,26 @@ func fixOddNumOfNodes(buffer [][]byte, bufferSize, step int) int {
}

// updateProofs updates the proofs for all the leaves while constructing the Merkle Tree.
func updateProofs(proofs []*Proof, buffer [][]byte, bufferSize, step int) {
func (m *MerkleTree) updateProofs(buffer [][]byte, bufferSize, step int) {
batch := 1 << step
for i := 0; i < bufferSize; i += 2 {
updateProofInTwoBatches(proofs, buffer, i, batch, step)
updateProofInTwoBatches(m.Proofs, buffer, i, batch, step)
}
}

// updateProofsParallel updates the proofs for all the leaves while constructing the Merkle Tree in parallel.
func updateProofsParallel(proofs []*Proof, buffer [][]byte, bufferLength, step, numRoutines int) {
func (m *MerkleTree) updateProofsParallel(buffer [][]byte, bufferLength, step int) {
var (
batch = 1 << step
wg sync.WaitGroup
)
numRoutines = min(numRoutines, bufferLength)
numRoutines := min(m.NumRoutines, bufferLength)
wg.Add(numRoutines)
for startIdx := 0; startIdx < numRoutines; startIdx++ {
go func(startIdx int) {
defer wg.Done()
for i := startIdx; i < bufferLength; i += numRoutines << 1 {
updateProofInTwoBatches(proofs, buffer, i, batch, step)
updateProofInTwoBatches(m.Proofs, buffer, i, batch, step)
}
}(startIdx << 1)
}
Expand Down
86 changes: 11 additions & 75 deletions proof_gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"sync/atomic"
"testing"

"github.com/agiledragon/gomonkey/v2"
Expand Down Expand Up @@ -343,38 +344,31 @@ func TestMerkleTree_proofGen(t *testing.T) {
}
}

func dummyBuffer(size int) [][]byte {
buffer := make([][]byte, size)
for i := 0; i < size; i++ {
buffer[i] = []byte(fmt.Sprintf("dummy_buffer_%d", i))
}
return buffer
}

func TestMerkleTree_proofGenParallel(t *testing.T) {
patches := gomonkey.NewPatches()
defer patches.Reset()
var hashFuncCounter atomic.Uint32
type args struct {
config *Config
blocks []DataBlock
}
tests := []struct {
name string
args args
mock func()
wantErr bool
}{
{
name: "test_workerProofGen_err",
name: "test_goroutine_err",
args: args{
config: &Config{
HashFunc: mockHashFunc,
HashFunc: func(data []byte) ([]byte, error) {
if hashFuncCounter.Load() == 9 {
return nil, errors.New("test_goroutine_err")
}
hashFuncCounter.Add(1)
return mockHashFunc(data)
},
RunInParallel: true,
},
blocks: mockDataBlocks(5),
},
mock: func() {
patches.ApplyFuncReturn(workerProofGen, errors.New("test_workerProofGen_err"))
blocks: mockDataBlocks(4),
},
wantErr: true,
},
Expand All @@ -386,67 +380,9 @@ func TestMerkleTree_proofGenParallel(t *testing.T) {
t.Errorf("New() error = %v", err)
return
}
if tt.mock != nil {
tt.mock()
}
defer patches.Reset()
if err := m.generateProofsParallel(); (err != nil) != tt.wantErr {
t.Errorf("generateProofsParallel() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

func Test_workerProofGen(t *testing.T) {
patches := gomonkey.NewPatches()
defer patches.Reset()
type args struct {
hashFunc TypeHashFunc
concatHashFunc typeConcatHashFunc
buffer [][]byte
bufferSize int
numRoutine int
startIdx int
step int
}
tests := []struct {
name string
args args
mock func()
wantErr bool
}{
{
name: "test_hash_func_err",
args: args{
hashFunc: mockHashFunc,
concatHashFunc: concatHash,
buffer: dummyBuffer(8),
bufferSize: 8,
numRoutine: 4,
startIdx: 0,
step: 0,
},
mock: func() {
patches.ApplyFunc(mockHashFunc,
func([]byte) ([]byte, error) {
return nil, errors.New("test_hash_func_err")
})
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.mock != nil {
tt.mock()
}
defer patches.Reset()
if err := workerProofGen(
tt.args.hashFunc, tt.args.concatHashFunc,
tt.args.buffer, tt.args.bufferSize, tt.args.numRoutine, tt.args.startIdx, tt.args.step,
); (err != nil) != tt.wantErr {
t.Errorf("workerProofGen() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
Loading