Skip to content

Commit

Permalink
feat(db): added Duration support in Scanner/Valuer (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi authored Mar 19, 2024
1 parent 572d453 commit 8af328c
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 7 deletions.
4 changes: 2 additions & 2 deletions bitbool.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ type BitBool bool

// Value implements the driver.Valuer interface,
// and turns the BitBool into a bit field (BIT(1)) for MySQL storage.
func (b BitBool) Value() (driver.Value, error) {
func (b BitBool) Value() (driver.Value, error) { // skipcq: GO-W1029
if b {
return []byte{1}, nil
} else {
Expand All @@ -20,7 +20,7 @@ func (b BitBool) Value() (driver.Value, error) {

// Scan implements the sql.Scanner interface,
// and turns the bit field incoming from MySQL into a BitBool
func (b *BitBool) Scan(src interface{}) error {
func (b *BitBool) Scan(src interface{}) error { // skipcq: GO-W1029
if src == nil {
return nil
}
Expand Down
47 changes: 47 additions & 0 deletions duration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package sqle

import (
"database/sql/driver"
"errors"
"time"
)

type Duration time.Duration

func (d Duration) Duration() time.Duration { // skipcq: GO-W1029
return time.Duration(d)
}

// Value implements the driver.Valuer interface,
// and turns the Duration into a VARCHAR field for MySQL storage.
func (d Duration) Value() (driver.Value, error) { // skipcq: GO-W1029
return time.Duration(d).String(), nil
}

// Scan implements the sql.Scanner interface,
// and turns the VARCHAR field incoming from MySQL into a Duration
func (d *Duration) Scan(src interface{}) error { // skipcq: GO-W1029
if src == nil {
return nil
}

var val string

switch v := src.(type) {
case []byte:
val = string(v)
case string:
val = v
default:
return errors.New("bad duration type assertion")
}

td, err := time.ParseDuration(val)
if err != nil {
return err
}

*d = Duration(td)

return nil
}
79 changes: 79 additions & 0 deletions duration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package sqle

import (
"database/sql"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestDuration(t *testing.T) {
d, err := sql.Open("sqlite3", "file::memory:")
require.NoError(t, err)

_, err = d.Exec("CREATE TABLE `users` (`id` id NOT NULL,`ttl` VARCHAR(20), `b_ttl` VARBINARY, PRIMARY KEY (`id`))")
require.NoError(t, err)

d10 := Duration(10 * time.Second)
d11 := Duration(11 * time.Second)
d12 := Duration(12 * time.Second)
d13 := Duration(13 * time.Second)

result, err := d.Exec("INSERT INTO `users`(`id`, `ttl`,`b_ttl`) VALUES(?, ?, ?)", 10, d10, d10)
require.NoError(t, err)

rows, err := result.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), rows)

result, err = d.Exec("INSERT INTO `users`(`id`, `ttl`) VALUES(?, ?)", 11, d11)
require.NoError(t, err)

rows, err = result.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), rows)

result, err = d.Exec("INSERT INTO `users`(`id`, `ttl`) VALUES(?, ?)", 12, d12)
require.NoError(t, err)

rows, err = result.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), rows)

result, err = d.Exec("INSERT INTO `users`(`id`, `ttl`) VALUES(?, ?)", 13, d13)
require.NoError(t, err)

rows, err = result.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), rows)

var b10 Duration
var b_b10 Duration
err = d.QueryRow("SELECT `ttl`, `b_ttl` FROM `users` WHERE id=?", 10).Scan(&b10, &b_b10)
require.NoError(t, err)

require.Equal(t, d10.Duration(), b10.Duration())
require.Equal(t, d10.Duration(), b_b10.Duration())

var b11 Duration
var b_b11 Duration
err = d.QueryRow("SELECT `ttl`, `b_ttl` FROM `users` WHERE id=?", 11).Scan(&b11, &b_b11)
require.NoError(t, err)

require.EqualValues(t, d11, b11)
require.Empty(t, b_b11)

var b12 Duration
err = d.QueryRow("SELECT `ttl` FROM `users` WHERE id=?", 12).Scan(&b12)
require.NoError(t, err)

require.EqualValues(t, d12, b12)

var b13 Duration
err = d.QueryRow("SELECT `ttl` FROM `users` WHERE id=?", 13).Scan(&b13)
require.NoError(t, err)

require.EqualValues(t, d13, b13)

}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.22
github.com/rs/zerolog v1.32.0
github.com/stretchr/testify v1.9.0
github.com/yaitoo/async v1.0.3
github.com/yaitoo/async v1.0.4
)

require (
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/yaitoo/async v1.0.3 h1:zBlE07xt/EkRqfrUU/zJSfIcsZJgnUDCOphRAvnT/Kk=
github.com/yaitoo/async v1.0.3/go.mod h1:IpSO7Ei7AxiqLxFqDjN4rJaVlt8wm4ZxMXyyQaWmM1g=
github.com/yaitoo/async v1.0.4 h1:u+SWuJcSckgBOcMjMYz9IviojeCatDrdni3YNGLCiHY=
github.com/yaitoo/async v1.0.4/go.mod h1:IpSO7Ei7AxiqLxFqDjN4rJaVlt8wm4ZxMXyyQaWmM1g=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
Expand Down
2 changes: 1 addition & 1 deletion migrate/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ func (m *Migrator) Rotate(ctx context.Context) error {
}

var week int
_, week = now.ISOWeek() //1-53 week
_, week = now.ISOWeek() // 1-53 week

next := now.AddDate(0, 0, 7)
_, nextWeek := next.ISOWeek()
Expand Down
2 changes: 1 addition & 1 deletion queryer_mapr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ func TestQueryLimit(t *testing.T) {
return q.QueryLimit(context.Background(), New().
Select("users", "id").
SQL("ORDER BY id DESC"), func(i, j MRUser) bool {
//DESC
// DESC
return j.ID < i.ID
}, 16)
},
Expand Down

0 comments on commit 8af328c

Please sign in to comment.