From 8809383e767c0eb19906a7b242e817ce21e6eb65 Mon Sep 17 00:00:00 2001 From: Peter Broadhurst Date: Sun, 31 Dec 2023 19:12:08 -0500 Subject: [PATCH] Add error return to query modifier Signed-off-by: Peter Broadhurst --- mocks/crudmocks/crud.go | 4 ++-- pkg/dbsql/crud.go | 18 +++++++++++------- pkg/dbsql/crud_test.go | 34 ++++++++++++++++++++++++++++------ pkg/dbsql/database.go | 6 ++++-- pkg/dbsql/database_test.go | 14 ++++++++++++-- 5 files changed, 57 insertions(+), 19 deletions(-) diff --git a/mocks/crudmocks/crud.go b/mocks/crudmocks/crud.go index 65de706..23aa474 100644 --- a/mocks/crudmocks/crud.go +++ b/mocks/crudmocks/crud.go @@ -374,7 +374,7 @@ func (_m *CRUD[T]) InsertMany(ctx context.Context, instances []T, allowPartialSu } // ModifyQuery provides a mock function with given fields: modifier -func (_m *CRUD[T]) ModifyQuery(modifier func(squirrel.SelectBuilder) squirrel.SelectBuilder) dbsql.CRUDQuery[T] { +func (_m *CRUD[T]) ModifyQuery(modifier func(squirrel.SelectBuilder) (squirrel.SelectBuilder, error)) dbsql.CRUDQuery[T] { ret := _m.Called(modifier) if len(ret) == 0 { @@ -382,7 +382,7 @@ func (_m *CRUD[T]) ModifyQuery(modifier func(squirrel.SelectBuilder) squirrel.Se } var r0 dbsql.CRUDQuery[T] - if rf, ok := ret.Get(0).(func(func(squirrel.SelectBuilder) squirrel.SelectBuilder) dbsql.CRUDQuery[T]); ok { + if rf, ok := ret.Get(0).(func(func(squirrel.SelectBuilder) (squirrel.SelectBuilder, error)) dbsql.CRUDQuery[T]); ok { r0 = rf(modifier) } else { if ret.Get(0) != nil { diff --git a/pkg/dbsql/crud.go b/pkg/dbsql/crud.go index c6e82a0..94c7bcc 100644 --- a/pkg/dbsql/crud.go +++ b/pkg/dbsql/crud.go @@ -172,14 +172,14 @@ func (c *CrudBase[T]) NewUpdateBuilder(ctx context.Context) ffapi.UpdateBuilder func (c *CrudBase[T]) ModifyQuery(newModifier QueryModifier) CRUDQuery[T] { cModified := *c originalModifier := cModified.ReadQueryModifier - cModified.ReadQueryModifier = func(sb sq.SelectBuilder) sq.SelectBuilder { + cModified.ReadQueryModifier = func(sb sq.SelectBuilder) (_ sq.SelectBuilder, err error) { if originalModifier != nil { - sb = originalModifier(sb) + sb, err = originalModifier(sb) } - if newModifier != nil { - sb = newModifier(sb) + if err == nil && newModifier != nil { + sb, err = newModifier(sb) } - return sb + return sb, err } return &cModified } @@ -595,7 +595,9 @@ func (c *CrudBase[T]) GetByID(ctx context.Context, id string, getOpts ...GetOpti From(tableFrom). Where(c.idFilter(id)) if c.ReadQueryModifier != nil { - query = c.ReadQueryModifier(query) + if query, err = c.ReadQueryModifier(query); err != nil { + return c.NilValue(), err + } } rows, _, err := c.DB.Query(ctx, c.Table, query) @@ -701,7 +703,9 @@ func (c *CrudBase[T]) getManyScoped(ctx context.Context, tableFrom string, fi *f return nil, nil, err } if c.ReadQueryModifier != nil { - query = c.ReadQueryModifier(query) + if query, err = c.ReadQueryModifier(query); err != nil { + return nil, nil, err + } } rows, tx, err := c.DB.Query(ctx, c.Table, query) diff --git a/pkg/dbsql/crud_test.go b/pkg/dbsql/crud_test.go index bfc71e7..3e83be7 100644 --- a/pkg/dbsql/crud_test.go +++ b/pkg/dbsql/crud_test.go @@ -288,8 +288,8 @@ func newLinkableCollection(db *Database, ns string) *CrudBase[*TestLinkable] { "description": "desc", "crud": "crud_id", }, - ReadQueryModifier: func(query sq.SelectBuilder) sq.SelectBuilder { - return query.LeftJoin("crudables AS c ON c.id = l.crud_id") + ReadQueryModifier: func(query sq.SelectBuilder) (sq.SelectBuilder, error) { + return query.LeftJoin("crudables AS c ON c.id = l.crud_id"), nil }, DefaultSort: func() []interface{} { // Return an empty list @@ -399,11 +399,11 @@ func TestCRUDWithDBEnd2End(t *testing.T) { checkEqualExceptTimes(t, *c1, *c1copy) // Check we get it back with custom modifiers - collection.ReadQueryModifier = func(sb sq.SelectBuilder) sq.SelectBuilder { - return sb.Where(sq.Eq{"ns": "ns1"}) + collection.ReadQueryModifier = func(sb sq.SelectBuilder) (sq.SelectBuilder, error) { + return sb.Where(sq.Eq{"ns": "ns1"}), nil } - c1copy, err = iCrud.ModifyQuery(func(sb sq.SelectBuilder) sq.SelectBuilder { - return sb.Where(sq.Eq{"field1": "hello1"}) + c1copy, err = iCrud.ModifyQuery(func(sb sq.SelectBuilder) (sq.SelectBuilder, error) { + return sb.Where(sq.Eq{"field1": "hello1"}), nil }).GetByName(ctx, *c1.Name) assert.NoError(t, err) checkEqualExceptTimes(t, *c1, *c1copy) @@ -895,6 +895,17 @@ func TestGetByIDScanFail(t *testing.T) { assert.NoError(t, mock.ExpectationsWereMet()) } +func TestGetByIDReqQueryModifierFail(t *testing.T) { + db, mock := NewMockProvider().UTInit() + tc := newCRUDCollection(&db.Database, "ns1") + tc.ReadQueryModifier = func(sb sq.SelectBuilder) (sq.SelectBuilder, error) { + return sb, fmt.Errorf("pop") + } + _, err := tc.GetByID(context.Background(), fftypes.NewUUID().String()) + assert.Regexp(t, "pop", err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + func TestGetByNameNoNameSemantics(t *testing.T) { db, _ := NewMockProvider().UTInit() tc := newLinkableCollection(&db.Database, "ns1") @@ -958,6 +969,17 @@ func TestGetManySelectFail(t *testing.T) { assert.NoError(t, mock.ExpectationsWereMet()) } +func TestGetManyReadModifierFail(t *testing.T) { + db, mock := NewMockProvider().UTInit() + tc := newCRUDCollection(&db.Database, "ns1") + tc.ReadQueryModifier = func(sb sq.SelectBuilder) (sq.SelectBuilder, error) { + return sb, fmt.Errorf("pop") + } + _, _, err := tc.GetMany(context.Background(), CRUDableQueryFactory.NewFilter(context.Background()).And()) + assert.Regexp(t, "pop", err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + func TestGetByManyScanFail(t *testing.T) { db, mock := NewMockProvider().UTInit() tc := newCRUDCollection(&db.Database, "ns1") diff --git a/pkg/dbsql/database.go b/pkg/dbsql/database.go index e74c242..40b372e 100644 --- a/pkg/dbsql/database.go +++ b/pkg/dbsql/database.go @@ -42,7 +42,7 @@ type Database struct { sequenceColumn string } -type QueryModifier = func(sq.SelectBuilder) sq.SelectBuilder +type QueryModifier = func(sq.SelectBuilder) (sq.SelectBuilder, error) // PreCommitAccumulator is a structure that can accumulate state during // the transaction, then has a function that is called just before commit. @@ -241,7 +241,9 @@ func (s *Database) CountQuery(ctx context.Context, table string, tx *TXWrapper, } q := sq.Select(fmt.Sprintf("COUNT(%s)", countExpr)).From(table).Where(fop) if qm != nil { - q = qm(q) + if q, err = qm(q); err != nil { + return -1, err + } } sqlQuery, args, err := q.PlaceholderFormat(s.features.PlaceholderFormat).ToSql() if err != nil { diff --git a/pkg/dbsql/database_test.go b/pkg/dbsql/database_test.go index 84c0e41..9a25456 100644 --- a/pkg/dbsql/database_test.go +++ b/pkg/dbsql/database_test.go @@ -545,14 +545,24 @@ func TestCountQueryWithExpr(t *testing.T) { func TestCountQueryWithModifier(t *testing.T) { s, mdb := NewMockProvider().UTInit() mdb.ExpectQuery("^SELECT COUNT\\(\\*\\)").WillReturnRows(sqlmock.NewRows([]string{"col1"}).AddRow(10)) - qm := func(sb sq.SelectBuilder) sq.SelectBuilder { - return sb.Where(sq.Eq{"col1": "val1"}) + qm := func(sb sq.SelectBuilder) (sq.SelectBuilder, error) { + return sb.Where(sq.Eq{"col1": "val1"}), nil } _, err := s.CountQuery(context.Background(), "table1", nil, sq.Eq{"col1": "val1"}, qm, "") assert.NoError(t, err) assert.NoError(t, mdb.ExpectationsWereMet()) } +func TestCountQueryWithModifierErr(t *testing.T) { + s, mdb := NewMockProvider().UTInit() + mdb.ExpectQuery("^SELECT COUNT\\(\\*\\)").WillReturnRows(sqlmock.NewRows([]string{"col1"}).AddRow(10)) + qm := func(sb sq.SelectBuilder) (sq.SelectBuilder, error) { + return sb, fmt.Errorf("pop") + } + _, err := s.CountQuery(context.Background(), "table1", nil, sq.Eq{"col1": "val1"}, qm, "") + assert.Regexp(t, "pop", err) +} + func TestQueryResSwallowError(t *testing.T) { s, _ := NewMockProvider().UTInit() res := s.QueryRes(context.Background(), "table1", nil, sq.Insert("wrong"), nil, &ffapi.FilterInfo{