Skip to content

Commit

Permalink
fix(rows): doesn't close rows on Scan (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi authored Dec 12, 2024
1 parent c243137 commit fe93ae7
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 27 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ jobs:
with:
go-version: ^1.21
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
uses: golangci/golangci-lint-action@v6
with:
version: v1.54
version: v1.61
unit-tests:
name: Unit Tests
runs-on: ubuntu-latest
Expand Down
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.5.1]
## [1.5.2] - 2024-12-12
- fix(rows): don't close rows in rows.Scan (#49)

## [1.5.1] - 2024-06-12
- !fix(orderby): use BuildOption instead of allowedColumns (#46)
- feat(string): added nullable String/Null for sql/json (#47)

Expand Down
30 changes: 16 additions & 14 deletions client_stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@ func TestStmt(t *testing.T) {
d, err := sql.Open("sqlite3", "file::memory:")
require.NoError(t, err)

_, err = d.Exec("CREATE TABLE `rows` (`id` int , `status` tinyint,`email` varchar(50),`passwd` varchar(120), `salt` varchar(45), `created` DATETIME, PRIMARY KEY (`id`))")
_, err = d.Exec("CREATE TABLE `rows_stmt` (`id` int , `status` tinyint,`email` varchar(50),`passwd` varchar(120), `salt` varchar(45), `created` DATETIME, PRIMARY KEY (`id`))")
require.NoError(t, err)

now := time.Now()

_, err = d.Exec("INSERT INTO `rows`(`id`,`status`,`email`,`passwd`,`salt`,`created`) VALUES(1, 1,'[email protected]','1xxxx','1zzzz', ?)", now)
_, err = d.Exec("INSERT INTO `rows_stmt`(`id`,`status`,`email`,`passwd`,`salt`,`created`) VALUES(1, 1,'[email protected]','1xxxx','1zzzz', ?)", now)
require.NoError(t, err)
_, err = d.Exec("INSERT INTO `rows`(`id`,`status`,`email`,`passwd`,`salt`) VALUES(2, 2,'[email protected]','2xxxx','2zzzz')")
_, err = d.Exec("INSERT INTO `rows_stmt`(`id`,`status`,`email`,`passwd`,`salt`) VALUES(2, 2,'[email protected]','2xxxx','2zzzz')")
require.NoError(t, err)

_, err = d.Exec("INSERT INTO `rows`(`id`,`status`,`email`,`passwd`,`salt`) VALUES(3, 3,'[email protected]','3xxxx','3zzzz')")
_, err = d.Exec("INSERT INTO `rows_stmt`(`id`,`status`,`email`,`passwd`,`salt`) VALUES(3, 3,'[email protected]','3xxxx','3zzzz')")
require.NoError(t, err)

_, err = d.Exec("INSERT INTO `rows`(`id`) VALUES(4)")
_, err = d.Exec("INSERT INTO `rows_stmt`(`id`) VALUES(4)")
require.NoError(t, err)

stmtMaxIdleTime := StmtMaxIdleTime
Expand All @@ -51,7 +51,7 @@ func TestStmt(t *testing.T) {
}

for i := 0; i < 100; i++ {
rows, err := db.Query("SELECT * FROM rows WHERE id<?", 4)
rows, err := db.Query("SELECT * FROM rows_stmt WHERE id<?", 4)
require.NoError(t, err)
var users []user
err = rows.Bind(&users)
Expand All @@ -78,7 +78,7 @@ func TestStmt(t *testing.T) {

for i := 0; i < 100; i++ {

row := db.QueryRow("SELECT * FROM rows WHERE id=?", 1)
row := db.QueryRow("SELECT * FROM rows_stmt WHERE id=?", 1)
require.NoError(t, row.err)
var user user
err = row.Bind(&user)
Expand All @@ -95,14 +95,14 @@ func TestStmt(t *testing.T) {
run: func(t *testing.T) {
for i := 0; i < 100; i++ {

result, err := db.Exec("INSERT INTO `rows`(`id`) VALUES(?)", i+100)
result, err := db.Exec("INSERT INTO `rows_stmt`(`id`) VALUES(?)", i+100)
require.NoError(t, err)
rows, err := result.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), rows)
}

rows, err := db.Query("SELECT id FROM rows WHERE id>=100 order by id")
rows, err := db.Query("SELECT id FROM rows_stmt WHERE id>=100 order by id")
require.NoError(t, err)
var list [][]int
err = rows.Bind(&list)
Expand All @@ -117,7 +117,7 @@ func TestStmt(t *testing.T) {
{
name: "stmt_reuse_should_work_in_exec",
run: func(t *testing.T) {
q := "INSERT INTO `rows`(`id`,`status`) VALUES(?, ?)"
q := "INSERT INTO `rows_stmt`(`id`,`status`) VALUES(?, ?)"

result, err := db.Exec(q, 200, 0)
require.NoError(t, err)
Expand Down Expand Up @@ -147,7 +147,7 @@ func TestStmt(t *testing.T) {
name: "stmt_reuse_should_work_in_rows_scan",
run: func(t *testing.T) {
var id int
q := "SELECT id, 'rows_scan' as reuse FROM rows WHERE id = ?"
q := "SELECT id, 'rows_scan' as reuse FROM rows_stmt WHERE id = ?"
rows, err := db.Query(q, 200)
require.NoError(t, err)

Expand All @@ -164,6 +164,8 @@ func TestStmt(t *testing.T) {
require.True(t, s.isUsing)

rows.Scan(&id) // nolint: errcheck
require.True(t, s.isUsing)
rows.Close()
require.False(t, s.isUsing)

db.closeStaleStmt()
Expand All @@ -181,7 +183,7 @@ func TestStmt(t *testing.T) {
ID int
}

q := "SELECT id, 'rows_bind' as reuse FROM rows WHERE id = ?"
q := "SELECT id, 'rows_bind' as reuse FROM rows_stmt WHERE id = ?"
rows, err := db.Query(q, 200)
require.NoError(t, err)

Expand Down Expand Up @@ -212,7 +214,7 @@ func TestStmt(t *testing.T) {
name: "stmt_reuse_should_work_in_row_scan",
run: func(t *testing.T) {
var id int
q := "SELECT id, 'row_scan' as reuse FROM rows WHERE id = ?"
q := "SELECT id, 'row_scan' as reuse FROM rows_stmt WHERE id = ?"
row := db.QueryRow(q, 200)
require.NoError(t, err)

Expand Down Expand Up @@ -245,7 +247,7 @@ func TestStmt(t *testing.T) {
var r struct {
ID int
}
q := "SELECT id, 'row_bind' as reuse FROM rows WHERE id = ?"
q := "SELECT id, 'row_bind' as reuse FROM rows_stmt WHERE id = ?"
row, err := db.Query(q, 200)
require.NoError(t, err)

Expand Down
14 changes: 7 additions & 7 deletions row.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,6 @@ func (r *Row) Bind(dest any) error {
return r.err
}

if !r.rows.Next() {
if err := r.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}

v := reflect.ValueOf(dest)

if v.Kind() != reflect.Pointer {
Expand All @@ -95,6 +88,13 @@ func (r *Row) Bind(dest any) error {
}

var err error
if !r.rows.Next() {
err = r.rows.Err()
if err != nil {
return err
}
return sql.ErrNoRows
}

cols, err := getColumns(r.query, r.rows)
if err != nil {
Expand Down
25 changes: 24 additions & 1 deletion row_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (cb *customBinder) Bind(_ reflect.Value, columns []string) []any {
return values
}

func TestRowBind(t *testing.T) {
func TestRow(t *testing.T) {

d, err := sql.Open("sqlite3", "file::memory:?cache=shared")
require.NoError(t, err)
Expand All @@ -91,6 +91,29 @@ func TestRowBind(t *testing.T) {
name string
run func(t *testing.T)
}{
{
name: "close_should_always_work",
run: func(*testing.T) {
var row *Row
row.Close()
row = &Row{}
row.Close()
},
},
{
name: "bind_only_work_with_non_nil_pointer",
run: func(t *testing.T) {

row := &Row{}
var dest int
err := row.Bind(dest)
require.ErrorIs(t, err, ErrMustPointer)

var dest2 *int
err = row.Bind(dest2)
require.ErrorIs(t, err, ErrMustNotNilPointer)
},
},
{
name: "full_columns_should_work",
run: func(t *testing.T) {
Expand Down
1 change: 0 additions & 1 deletion rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ func (r *Rows) Close() error {
}

func (r *Rows) Scan(dest ...any) error {
defer r.Close()
return r.Rows.Scan(dest...)
}

Expand Down
53 changes: 52 additions & 1 deletion rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/require"
)

func TestRowsBind(t *testing.T) {
func TestRows(t *testing.T) {

d, err := sql.Open("sqlite3", "file::memory:")
require.NoError(t, err)
Expand All @@ -36,6 +36,57 @@ func TestRowsBind(t *testing.T) {
name string
run func(t *testing.T)
}{
{
name: "close_should_always_work",
run: func(*testing.T) {

rows := &Rows{}
rows.Close()
},
},
{
name: "bind_only_work_with_non_nil_pointer",
run: func(t *testing.T) {

rows := &Rows{}
var dest int
err := rows.Bind(dest)
require.ErrorIs(t, err, ErrMustPointer)

var dest2 *int
err = rows.Bind(dest2)
require.ErrorIs(t, err, ErrMustNotNilPointer)
},
},
{
name: "scan_on_rows_should_work",
run: func(t *testing.T) {

rows, err := db.Query("SELECT id,email FROM rows WHERE id<4")
require.NoError(t, err)

defer rows.Close()

var id int
var email string

rows.Next()
err = rows.Scan(&id, &email)
require.NoError(t, err)
require.Equal(t, 1, id)
require.Equal(t, "[email protected]", email)
rows.Next()
err = rows.Scan(&id, &email)
require.NoError(t, err)
require.Equal(t, 2, id)
require.Equal(t, "[email protected]", email)
rows.Next()
err = rows.Scan(&id, &email)
require.NoError(t, err)
require.Equal(t, 3, id)
require.Equal(t, "[email protected]", email)
},
},
{
name: "bind_slice_of_struct_should_work",
run: func(t *testing.T) {
Expand Down

0 comments on commit fe93ae7

Please sign in to comment.