diff --git a/CHANGELOG.md b/CHANGELOG.md index 734c89ebf..8ebe24d2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The following emojis are used to highlight certain changes: ### Added - `blockservice` now has `ContextWithSession` and `EmbedSessionInContext` functions, which allows to embed a session in a context. Future calls to `BlockGetter.GetBlock`, `BlockGetter.GetBlocks` and `NewSession` will use the session in the context. +- `blockservice` now has `WithContentBlocker` option which allows to filter Add and Get requests by CID. ### Changed diff --git a/blockservice/blockservice.go b/blockservice/blockservice.go index 7733788ec..d8a6f5e96 100644 --- a/blockservice/blockservice.go +++ b/blockservice/blockservice.go @@ -71,10 +71,24 @@ type BoundedBlockService interface { Allowlist() verifcid.Allowlist } +// Blocker returns err != nil if the CID is disallowed to be fetched or stored in blockservice. +// It returns an error so error messages could be passed. +type Blocker func(cid.Cid) error + +// BlockedBlockService is a Blockservice bounded via an arbitrary cid [Blocker]. +type BlockedBlockService interface { + BlockService + + // Blocker might return [nil], then no blocking is to be done. + Blocker() Blocker +} + var _ BoundedBlockService = (*blockService)(nil) +var _ BlockedBlockService = (*blockService)(nil) type blockService struct { allowlist verifcid.Allowlist + blocker Blocker blockstore blockstore.Blockstore exchange exchange.Interface // If checkFirst is true then first check that a block doesn't @@ -99,6 +113,13 @@ func WithAllowlist(allowlist verifcid.Allowlist) Option { } } +// WithContentBlocker allows to filter what blocks can be fetched or added to the blockservice. +func WithContentBlocker(blocker Blocker) Option { + return func(bs *blockService) { + bs.blocker = blocker + } +} + // New creates a BlockService with given datastore instance. func New(bs blockstore.Blockstore, exchange exchange.Interface, opts ...Option) BlockService { if exchange == nil { @@ -141,6 +162,10 @@ func (s *blockService) Allowlist() verifcid.Allowlist { return s.allowlist } +func (s *blockService) Blocker() Blocker { + return s.blocker +} + // NewSession creates a new session that allows for // controlled exchange of wantlists to decrease the bandwidth overhead. // If the current exchange is a SessionExchange, a new exchange @@ -171,6 +196,13 @@ func (s *blockService) AddBlock(ctx context.Context, o blocks.Block) error { if err != nil { return err } + + if s.blocker != nil { + if err := s.blocker(c); err != nil { + return err + } + } + if s.checkFirst { if has, err := s.blockstore.Has(ctx, c); has || err != nil { return err @@ -198,10 +230,17 @@ func (s *blockService) AddBlocks(ctx context.Context, bs []blocks.Block) error { // hash security for _, b := range bs { - err := verifcid.ValidateCid(s.allowlist, b.Cid()) + c := b.Cid() + err := verifcid.ValidateCid(s.allowlist, c) if err != nil { return err } + + if s.blocker != nil { + if err := s.blocker(c); err != nil { + return err + } + } } var toput []blocks.Block if s.checkFirst { @@ -261,6 +300,12 @@ func getBlock(ctx context.Context, c cid.Cid, bs BlockService, fetchFactory func return nil, err } + if blocker := grabBlockerFromBlockservice(bs); blocker != nil { + if err := blocker(c); err != nil { + return nil, err + } + } + blockstore := bs.Blockstore() block, err := blockstore.Get(ctx, c) @@ -320,6 +365,7 @@ func getBlocks(ctx context.Context, ks []cid.Cid, blockservice BlockService, fet defer close(out) allowlist := grabAllowlistFromBlockservice(blockservice) + blocker := grabBlockerFromBlockservice(blockservice) var lastAllValidIndex int var c cid.Cid @@ -327,6 +373,12 @@ func getBlocks(ctx context.Context, ks []cid.Cid, blockservice BlockService, fet if err := verifcid.ValidateCid(allowlist, c); err != nil { break } + + if blocker != nil { + if err := blocker(c); err != nil { + break + } + } } if lastAllValidIndex != len(ks) { @@ -335,11 +387,19 @@ func getBlocks(ctx context.Context, ks []cid.Cid, blockservice BlockService, fet copy(ks2, ks[:lastAllValidIndex]) // fast path for already filtered elements for _, c := range ks[lastAllValidIndex:] { // don't rescan already scanned elements // hash security - if err := verifcid.ValidateCid(allowlist, c); err == nil { - ks2 = append(ks2, c) - } else { + if err := verifcid.ValidateCid(allowlist, c); err != nil { logger.Errorf("unsafe CID (%s) passed to blockService.GetBlocks: %s", c, err) + continue + } + + if blocker != nil { + if err := blocker(c); err != nil { + logger.Errorf("blocked CID (%s) passed to blockService.GetBlocks: %s", c, err) + continue + } } + + ks2 = append(ks2, c) } ks = ks2 } @@ -526,3 +586,10 @@ func grabAllowlistFromBlockservice(bs BlockService) verifcid.Allowlist { } return verifcid.DefaultAllowlist } + +func grabBlockerFromBlockservice(bs BlockService) Blocker { + if bbs, ok := bs.(BlockedBlockService); ok { + return bbs.Blocker() + } + return nil +} diff --git a/blockservice/blockservice_test.go b/blockservice/blockservice_test.go index 53fd725f3..a1bd5d934 100644 --- a/blockservice/blockservice_test.go +++ b/blockservice/blockservice_test.go @@ -2,6 +2,7 @@ package blockservice import ( "context" + "errors" "testing" blockstore "github.com/ipfs/boxo/blockstore" @@ -353,3 +354,72 @@ func TestContextSession(t *testing.T) { "session must be deduped in all invocations on the same context", ) } + +func TestBlocker(t *testing.T) { + t.Parallel() + a := assert.New(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + bgen := butil.NewBlockGenerator() + allowed := bgen.Next() + notAllowed := bgen.Next() + + var disallowed = errors.New("disallowed") + + bs := blockstore.NewBlockstore(dssync.MutexWrap(ds.NewMapDatastore())) + service := New(bs, nil, WithContentBlocker(func(c cid.Cid) error { + if c == notAllowed.Cid() { + return disallowed + } + return nil + })) + + // try putting + a.NoError(service.AddBlock(ctx, allowed)) + has, err := bs.Has(ctx, allowed.Cid()) + a.NoError(err) + a.True(has, "block was not added even tho it is not blocked") + a.NoError(service.DeleteBlock(ctx, allowed.Cid())) + + a.ErrorIs(service.AddBlock(ctx, notAllowed), disallowed) + has, err = bs.Has(ctx, notAllowed.Cid()) + a.NoError(err) + a.False(has, "block was added even tho it is blocked") + + a.NoError(service.AddBlocks(ctx, []blocks.Block{allowed})) + has, err = bs.Has(ctx, allowed.Cid()) + a.NoError(err) + a.True(has, "block was not added even tho it is not blocked") + a.NoError(service.DeleteBlock(ctx, allowed.Cid())) + + a.ErrorIs(service.AddBlocks(ctx, []blocks.Block{notAllowed}), disallowed) + has, err = bs.Has(ctx, notAllowed.Cid()) + a.NoError(err) + a.False(has, "block was added even tho it is blocked") + + // now try fetch + a.NoError(bs.Put(ctx, allowed)) + a.NoError(bs.Put(ctx, notAllowed)) + + block, err := service.GetBlock(ctx, allowed.Cid()) + a.NoError(err) + a.Equal(block.RawData(), allowed.RawData()) + + _, err = service.GetBlock(ctx, notAllowed.Cid()) + a.ErrorIs(err, disallowed) + + var gotAllowed bool + for block := range service.GetBlocks(ctx, []cid.Cid{allowed.Cid(), notAllowed.Cid()}) { + switch block.Cid() { + case allowed.Cid(): + gotAllowed = true + case notAllowed.Cid(): + t.Error("got disallowed block") + default: + t.Fatalf("got unrelated block: %s", block.Cid()) + } + } + a.True(gotAllowed, "did not got allowed block") +}