Skip to content

Commit

Permalink
fix: concurrency bug from the ongoingLookups
Browse files: browse the repository at this point in the history
  • Loading branch information
2color committed Dec 4, 2024
1 parent 811dce8 commit b4da9cd
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 96 deletions.
79 changes: 42 additions & 37 deletions server_cached_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"errors"
"sync/atomic"
"time"

Expand All @@ -26,6 +27,8 @@ var (
},
[]string{addrCacheStateLabel, addrQueryOriginLabel},
)

errNoValueAvailable = errors.New("no value available")
)

const (
Expand Down Expand Up @@ -107,85 +110,87 @@ func NewCacheFallbackIter(sourceIter iter.ResultIter[types.Record], router cache
router: router,
ctx: ctx,
cancel: cancel,
findPeersResult: make(chan types.PeerRecord, 1),
findPeersResult: make(chan types.PeerRecord),
ongoingLookups: atomic.Int32{},
}
}

func (it *cacheFallbackIter) Next() bool {
// load up current val from source iterator and avoid blocking on channel
// Try to get the next value from the source iterator first
if it.sourceIter.Next() {
val := it.sourceIter.Val()
handleRecord := func(id *peer.ID, record *types.PeerRecord) bool {
record.Addrs = it.router.withAddrsFromCache(addrQueryOriginProviders, id, record.Addrs)
if record.Addrs != nil { // if we have addrs, return them
if record.Addrs != nil {
it.current = iter.Result[types.Record]{Val: record}
return true
}
// If a record has no addrs, we need to look it up.
go it.dispatchFindPeer(*id)
if it.sourceIter.Next() { // In the meantime, we continue reading from source iterator if we have more results
it.current = it.sourceIter.Val()
return true
}
return it.ongoingLookups.Load() > 0 // If there are no more results from the source iterator, and no ongoing lookups, we're done.
logger.Infow("no cached addresses found in cacheFallbackIter, dispatching find peers", "peer", id)
// If a record has no addrs, we dispatch a lookup to find addresses
go it.dispatchFindPeer(*record)
// important to increment here since Next() may be called again synchronously
it.ongoingLookups.Add(1)

return it.Next() // Recursively call Next() to either read from sourceIter or wait for lookup result
}

switch val.Val.GetSchema() {
case types.SchemaBitswap:
if record, ok := val.Val.(*types.BitswapRecord); ok {
// we convert to peer record to handle uniformly
return handleRecord(record.ID, types.FromBitswapRecord(record))
}
case types.SchemaPeer:
if record, ok := val.Val.(*types.PeerRecord); ok {
return handleRecord(record.ID, record)
}
default:
// we don't know how to handle this schema, so we just return the record as is
it.current = val
return true
}
it.current = val // pass through unknown schemas
return true
}
// source iterator is exhausted, check if there are any peers left to look up

// If there are still ongoing lookups, wait for them
if it.ongoingLookups.Load() > 0 {
// if there are any ongoing lookups, return true to keep iterating
return true
logger.Infow("waiting for ongoing find peers result")
select {
case result, ok := <-it.findPeersResult:
if ok {
it.current = iter.Result[types.Record]{Val: &result}
return true
}
case <-it.ctx.Done():
return false
}
}
// if there are no ongoing lookups and the source iterator is exhausted, we're done

return false
}

func (it *cacheFallbackIter) dispatchFindPeer(pid peer.ID) {
it.ongoingLookups.Add(1)
// Val returns the iterator's current result.
//
// If the current result carries neither a value nor an error, a result
// wrapping errNoValueAvailable is returned instead, so callers always get
// either a record or an error. Val only reads it.current; it does not advance
// the iterator and does not touch the find-peers channel.
func (it *cacheFallbackIter) Val() iter.Result[types.Record] {
	// Guard clause: nothing has been stored in current yet.
	if it.current.Val == nil && it.current.Err == nil {
		return iter.Result[types.Record]{Err: errNoValueAvailable}
	}
	return it.current
}

func (it *cacheFallbackIter) dispatchFindPeer(record types.PeerRecord) {
defer it.ongoingLookups.Add(-1)
// FindPeers is weird in that it accepts a limit. But we only want one result, ideally from the libp2p router.
peersIt, err := it.router.FindPeers(it.ctx, pid, 1)
peersIt, err := it.router.FindPeers(it.ctx, *record.ID, 1)

if err != nil {
logger.Errorw("error looking up peer", "peer", pid, "error", err)
it.findPeersResult <- record // pass back the record with no addrs
return
}
peers, err := iter.ReadAllResults(peersIt)
if err != nil {
logger.Errorw("error reading find peers results", "peer", pid, "error", err)
it.findPeersResult <- record // pass back the record with no addrs
return
}
if len(peers) > 0 {
// If we found the peer, pass back
it.findPeersResult <- *peers[0]
} else {
logger.Errorw("no peer was found in cachedFallbackIter", "peer", pid)
}
}

func (it *cacheFallbackIter) Val() iter.Result[types.Record] {
select {
case <-it.ctx.Done():
return iter.Result[types.Record]{Err: it.ctx.Err()}
case foundPeer := <-it.findPeersResult:
// read from channel if available
return iter.Result[types.Record]{Val: &foundPeer}
default:
return it.current
it.findPeersResult <- record // pass back the record with no addrs
}
}

Expand Down
Loading

0 comments on commit b4da9cd

Please sign in to comment.