diff --git a/CHANGELOG.md b/CHANGELOG.md index bfcd607..2a1b0db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added - added `BitBool` for mysql bit type (#11) +- added `sharding` feature (#12) + +### Fixed +- fixed parameterized placeholder for postgresql(#12) ## [1.1.0] - 2024-02-13 ### Added diff --git a/README.md b/README.md index 584dc6d..1ecb546 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,19 @@ You’ll find the SQLE package useful if you’re not a fan of full-featured ORM ## Tutorials > All examples on https://go.dev/doc/tutorial/database-access can directly work with `sqle.DB` instance. +> + +### Install SQLE +- install latest commit from `main` branch +``` +go get github.com/yaitoo/sqle@main +``` + +- install latest release +``` +go get github.com/yaitoo/sqle@latest +``` + ### Connecting to a Database SQLE directly connects to a database by `sql.DB` instance. ``` @@ -235,7 +248,7 @@ func albumsByArtist(name string) ([]Album, error) { // An albums slice to hold data from returned rows. var albums []Album - cmd := sql.New("SELECT * FROM album").Where(). + cmd := sql.New().Select("album").Where(). If(name != "").And("artist = {artist}"). Param("artist",name) @@ -291,7 +304,7 @@ func albumByID(id int64) (Album, error) { func albumByID(id int64) (Album, error) { // An album to hold data from the returned row. var alb Album - cmd := sqle.New("SELECT * FROM album"). + cmd := sqle.New().Select("album"). Where("id = {id}"). Param("id",id) @@ -383,7 +396,7 @@ func deleteAlbumByID(id int64) error { - delete album by named sql statement ``` func deleteAlbumByID(id int64) error { - _, err := db.ExecBuilder(context.TODO(), sqle.New("DELETE FROM album WHERE id = {id}"). + _, err := db.ExecBuilder(context.TODO(), sqle.New().Delete("album").Where("id = {id}"). Param("id",id)) return err @@ -395,7 +408,7 @@ perform a set of operations within a transaction ``` func deleteAlbums(ids []int64) error { - return db.Transaction(ctx, &sql.TxOptions{}, func(tx *sqle.Tx) error { + return db.Transaction(ctx, &sql.TxOptions{}, func(ctx context.Context,tx *sqle.Tx) error { var err error for _, id := range ids { _, err = tx.Exec("DELETE FROM album WHERE id=?",id) diff --git a/sharding/README.md b/sharding/README.md new file mode 100644 index 0000000..87811cf --- /dev/null +++ b/sharding/README.md @@ -0,0 +1,49 @@ +# sharding + +## sid-64-bit +// +----------+-------------------+------------+----------------+----------------------+---------------+ +// | signed 1 | millis (39) | worker(2) | db-sharding(10)| table-rotate(2) | sequence(10) | +// +----------+-------------------+------------+----------------+----------------------+---------------+ + 39 = 17 years 2 = 4 10=1024 0: none :table 10=1024 + 1: monthly :table-[YYYYMM] + 2: weekly :table-[YYYY0XX] + 3: daily :table-[YYYYMMDD] +- signed(1): sid is always positive number +- millis(39): 2^39 (17years) unix milliseconds since 2024-02-19 00:00:00 +- workers(4): 2^4(16) workers +- db-sharding(10): 2^10 (1024) db instances +- table-rotate(2): 2^2(4) table rotate: none/by year/by month/by day +- sequence(10): 2^10(1024) per milliseconds + +## TPS: + - ID: 1000(ms)*1024(seq)*4 = 4096000 409.6W/s + 1000*1024 = 1024000 102.4W/s + + - DB : + 10 * 1000 = 10000 1W/s + 1024 * 1000 = 1024000 102.4W/s + + 10 * 2000 = 20000 2W/s + 1024 * 2000 = 2048000 204.8W/s + + 10 * 3000 = 30000 3W/s + 1024 * 3000 = 3072000 307.2W/s + +## mysql-benchmark + - https://docs.aws.amazon.com/whitepapers/latest/optimizing-mysql-on-ec2-using-amazon-ebs/mysql-benchmark-observations-and-considerations.html + - https://github.com/MinervaDB/MinervaDB-Sysbench + - https://www.percona.com/blog/assessing-mysql-performance-amongst-aws-options-part-one/ + +## issues +- Overflow capacity + waiting for next microsecond. + +- System Clock Dependency + You should use NTP to keep your system clock accurate. + +- Time move backwards + + if sequence doesn't overflow, let's use last timestamp and next sequence. system clock might moves forward and greater than last timestamp on next id generation + + if sequence overflows, and has to be reset. let's built-in clock to get timestamp till system clock moves forward and greater than built-in clock + +- Built-in clock + record last timestamp in memory/database, increase it when it is requested to send current timestamp instead of system clock \ No newline at end of file diff --git a/sharding/generator.go b/sharding/generator.go new file mode 100644 index 0000000..212e411 --- /dev/null +++ b/sharding/generator.go @@ -0,0 +1,101 @@ +package sharding + +import ( + "sync" + "time" +) + +type Generator struct { + sync.Mutex + _ noCopy // nolint: unused + + workerID int8 + databaseTotal int16 + tableRotate TableRotate + now func() time.Time + + lastMillis int64 + nextSequence int16 + nextDatabaseID int16 +} + +func New(options ...Option) *Generator { + g := &Generator{ + now: time.Now, + databaseTotal: 1, + tableRotate: None, + workerID: acquireWorkerID(), + } + for _, option := range options { + option(g) + } + return g +} + +func (g *Generator) Next() int64 { + g.Lock() + + defer func() { + g.nextSequence++ + g.Unlock() + }() + + nowMillis := g.now().UnixMilli() + if nowMillis < g.lastMillis { + if g.nextSequence > MaxSequence { + // time move backwards,and sequence overflows capacity, waiting system clock to move forward + g.nextSequence = 0 + nowMillis = g.tillNextMillis() + } else { + // time move backwards,but sequence doesn't overflow capacity, use Built-in clock to move forward + nowMillis = g.moveNextMillis() + } + } + + // sequence overflows capacity + if g.nextSequence > MaxSequence { + if nowMillis == g.lastMillis { + nowMillis = g.tillNextMillis() + } + + g.nextSequence = 0 + } + + g.lastMillis = nowMillis + + return Build(nowMillis, g.workerID, g.getNextDatabaseID(), g.tableRotate, g.nextSequence) + +} + +func (g *Generator) getNextDatabaseID() int16 { + if g.databaseTotal <= 1 { + return 0 + } + + defer func() { + g.nextDatabaseID++ + }() + + if g.nextDatabaseID < g.databaseTotal { + return g.nextDatabaseID + } + + g.nextDatabaseID = 0 + return 0 +} + +func (g *Generator) tillNextMillis() int64 { + lastMillis := g.now().UnixMilli() + for { + if lastMillis > g.lastMillis { + break + } + + lastMillis = g.now().UnixMilli() + } + + return lastMillis +} +func (g *Generator) moveNextMillis() int64 { + return g.lastMillis + 1 +} diff --git a/sharding/generator_test.go b/sharding/generator_test.go new file mode 100644 index 0000000..b16e512 --- /dev/null +++ b/sharding/generator_test.go @@ -0,0 +1,283 @@ +package sharding + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestGenerator(t *testing.T) { + tests := []struct { + name string + new func() *Generator + assert func(t *testing.T, gen *Generator) + }{ + { + name: "sequence_should_work", + new: func() *Generator { + g := New(WithTimeNow(func() time.Time { + return time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC) + })) + + return g + }, + assert: func(t *testing.T, gen *Generator) { + id := gen.Next() + want := Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 0, 0, None, 0) + require.Equal(t, want, id) + + id = gen.Next() + want = Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 0, 0, None, 1) + require.Equal(t, want, id) + + id = gen.Next() + want = Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 0, 0, None, 2) + require.Equal(t, want, id) + + }, + }, + { + name: "worker_id_should_work", + new: func() *Generator { + g := New(WithTimeNow(func() time.Time { + return time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC) + }), WithWorkerID(1)) + + return g + }, + assert: func(t *testing.T, gen *Generator) { + id := gen.Next() + want := Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 0, None, 0) + require.Equal(t, want, id) + + id = gen.Next() + want = Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 0, None, 1) + require.Equal(t, want, id) + + id = gen.Next() + want = Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 0, None, 2) + require.Equal(t, want, id) + + }, + }, + { + name: "database_id_should_work", + new: func() *Generator { + g := New(WithTimeNow(func() time.Time { + return time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC) + }), WithWorkerID(1), WithDatabase(3)) + + return g + }, + assert: func(t *testing.T, gen *Generator) { + id := gen.Next() + want := Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 0, None, 0) + require.Equal(t, want, id) + + id = gen.Next() + want = Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 1, None, 1) + require.Equal(t, want, id) + + id = gen.Next() + want = Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 2, None, 2) + require.Equal(t, want, id) + + }, + }, + { + name: "database_id_should_reset", + new: func() *Generator { + g := New(WithTimeNow(func() time.Time { + return time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC) + }), WithWorkerID(1), WithDatabase(2)) + + return g + }, + assert: func(t *testing.T, gen *Generator) { + id := gen.Next() + want := Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 0, None, 0) + require.Equal(t, want, id) + + id = gen.Next() + want = Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 1, None, 1) + require.Equal(t, want, id) + + id = gen.Next() + want = Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 0, None, 2) + require.Equal(t, want, id) + + id = gen.Next() + want = Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 1, None, 3) + require.Equal(t, want, id) + + }, + }, + { + name: "monthly_rotate_should_work", + new: func() *Generator { + g := New(WithTimeNow(func() time.Time { + return time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC) + }), WithWorkerID(1), WithDatabase(3), WithTableRotate(Monthly)) + + return g + }, + assert: func(t *testing.T, gen *Generator) { + id := gen.Next() + want := Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 0, Monthly, 0) + require.Equal(t, want, id) + + md := Parse(id) + require.Equal(t, "202402", md.RotateName()) + }, + }, + { + name: "weekly_rotate_should_work", + new: func() *Generator { + g := New(WithTimeNow(func() time.Time { + return time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC) + }), WithWorkerID(1), WithDatabase(3), WithTableRotate(Weekly)) + + return g + }, + assert: func(t *testing.T, gen *Generator) { + id := gen.Next() + want := Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 0, Weekly, 0) + require.Equal(t, want, id) + + md := Parse(id) + require.Equal(t, "2024008", md.RotateName()) + }, + }, + { + name: "daily_rotate_should_work", + new: func() *Generator { + g := New(WithTimeNow(func() time.Time { + return time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC) + }), WithWorkerID(1), WithDatabase(3), WithTableRotate(Daily)) + + return g + }, + assert: func(t *testing.T, gen *Generator) { + id := gen.Next() + want := Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 0, Daily, 0) + require.Equal(t, want, id) + + md := Parse(id) + require.Equal(t, "20240220", md.RotateName()) + }, + }, + { + name: "sequence_overflows_capacity_should_work", + new: func() *Generator { + i := 0 + g := New(WithTimeNow(func() time.Time { + defer func() { + i++ + }() + + return time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).Add(time.Duration(i) * time.Millisecond) + + }), WithWorkerID(1), WithTableRotate(Daily)) + + return g + }, + assert: func(t *testing.T, gen *Generator) { + gen.nextSequence = MaxSequence + id := gen.Next() + want := Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 0, Daily, MaxSequence) + require.Equal(t, want, id) + + md := Parse(id) + require.Equal(t, "20240220", md.RotateName()) + + id = gen.Next() + want = Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).Add(1*time.Millisecond).UnixMilli(), 1, 0, Daily, 0) + require.Equal(t, want, id) + + md = Parse(id) + require.Equal(t, "20240220", md.RotateName()) + }, + }, + { + name: "time_move_backwards_should_work", + new: func() *Generator { + i := 0 + g := New(WithTimeNow(func() time.Time { + defer func() { + i++ + }() + + if i == 1 { + return time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).Add(-1 * time.Millisecond) + } + + return time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).Add(time.Duration(i) * time.Millisecond) + + }), WithWorkerID(1), WithTableRotate(Daily)) + + return g + }, + assert: func(t *testing.T, gen *Generator) { + id := gen.Next() + want := Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 0, Daily, 0) + require.Equal(t, want, id) + + md := Parse(id) + require.Equal(t, "20240220", md.RotateName()) + + id = gen.Next() + want = Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).Add(1*time.Millisecond).UnixMilli(), 1, 0, Daily, 1) + require.Equal(t, want, id) + + md = Parse(id) + require.Equal(t, "20240220", md.RotateName()) + + }, + }, + { + name: "time_move_backwards_and_sequence_overflows_capacity_should_work", + new: func() *Generator { + i := 0 + g := New(WithTimeNow(func() time.Time { + defer func() { + i++ + }() + + if i == 1 { + return time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).Add(-1 * time.Millisecond) + } + + return time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).Add(time.Duration(i) * time.Millisecond) + + }), WithWorkerID(1), WithTableRotate(Daily)) + + return g + }, + assert: func(t *testing.T, gen *Generator) { + gen.nextSequence = MaxSequence + id := gen.Next() + want := Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).UnixMilli(), 1, 0, Daily, MaxSequence) + require.Equal(t, want, id) + + md := Parse(id) + require.Equal(t, "20240220", md.RotateName()) + + id = gen.Next() + want = Build(time.Date(2024, 2, 20, 0, 0, 0, 0, time.UTC).Add(2*time.Millisecond).UnixMilli(), 1, 0, Daily, 0) + require.Equal(t, want, id) + + md = Parse(id) + require.Equal(t, "20240220", md.RotateName()) + + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := test.new() + test.assert(t, g) + }) + } +} diff --git a/sharding/id.go b/sharding/id.go new file mode 100644 index 0000000..250efad --- /dev/null +++ b/sharding/id.go @@ -0,0 +1,89 @@ +package sharding + +import ( + "fmt" + "time" +) + +const ( + + // TimeEpoch : 2024-2-19 + TimeEpoch int64 = 1708300800000 + // TimeEnd : 2041-2-19 + TimeEnd int64 = 2244844800000 + + // TimeMillisBits milliseconds since TimeEpoch + TimeMillisBits = 39 + // WorkerBits worker id: 0-3 + WorkerBits = 2 + // DatabaseBits database id: 0-1023 + DatabaseBits = 10 + // TableBits table sharding: 0=none/1=yyyyMM/2=yyyy0WW/3=yyyyMMDD + TableBits = 2 + // SequenceBits sequence: 0-1023 + SequenceBits = 10 + + TimeNowShift = WorkerBits + DatabaseBits + TableBits + SequenceBits + WorkerShift = DatabaseBits + TableBits + SequenceBits + DatabaseShift = TableBits + SequenceBits + TableShift = SequenceBits + + MaxSequence int16 = -1 ^ (-1 << SequenceBits) //1023 + MaxTableShard int8 = -1 ^ (-1 << TableBits) + MaxDatabaseID int16 = -1 ^ (-1 << DatabaseBits) + MaxWorkerID int8 = -1 ^ (-1 << WorkerBits) + MaxTimeMillis int64 = -1 ^ (-1 << TimeMillisBits) +) + +type TableRotate int8 + +var ( + None TableRotate = 0 + Monthly TableRotate = 1 + Weekly TableRotate = 2 + Daily TableRotate = 3 +) + +type ID struct { + Time time.Time + ID int64 + TimeMillis int64 + + Sequence int16 + DatabaseID int16 + + WorkerID int8 + TableRotate TableRotate +} + +func (i *ID) RotateName() string { + switch i.TableRotate { + case Daily: + return i.Time.Format("20060102") + case Weekly: + _, week := i.Time.ISOWeek() //1-53 week + return i.Time.Format("2006") + fmt.Sprintf("%03d", week) + case Monthly: + return i.Time.Format("200601") + default: + return "" + } +} + +func Build(timeNow int64, workerID int8, databaseID int16, tr TableRotate, sequence int16) int64 { + return int64(timeNow-TimeEpoch)<>TableShift) & MaxTableShard), + DatabaseID: int16(id>>DatabaseShift) & MaxDatabaseID, + WorkerID: int8(id>>WorkerShift) & MaxWorkerID, + TimeMillis: int64(id>>TimeNowShift)&MaxTimeMillis + TimeEpoch, + } + s.Time = time.UnixMilli(s.TimeMillis).UTC() + + return s +} diff --git a/sharding/id_test.go b/sharding/id_test.go new file mode 100644 index 0000000..a636023 --- /dev/null +++ b/sharding/id_test.go @@ -0,0 +1,98 @@ +package sharding + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestID(t *testing.T) { + tests := []struct { + name string + build func() int64 + timeNow time.Time + workerID int8 + databaseID int16 + tableRotate TableRotate + sequence int16 + orderby bool + }{ + { + name: "build_min_values_should_work", + timeNow: time.UnixMilli(TimeEpoch), + workerID: 0, + databaseID: 0, + tableRotate: None, + sequence: 0, + }, + { + name: "build_max_values_should_work", + timeNow: time.UnixMilli(TimeEnd), + workerID: MaxWorkerID, + databaseID: MaxDatabaseID, + tableRotate: Daily, + sequence: MaxSequence, + }, + { + name: "build_should_work", + timeNow: time.Now(), + workerID: int8(rand.Intn(4)), + databaseID: int16(rand.Intn(1024)), + tableRotate: Weekly, + sequence: int16(rand.Intn(1024)), + }, + { + name: "id_should_orderable", + timeNow: time.Now(), + workerID: 0, + databaseID: 0, + tableRotate: Monthly, + sequence: 0, + orderby: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + id := Build(test.timeNow.UnixMilli(), test.workerID, test.databaseID, test.tableRotate, test.sequence) + + result := Parse(id) + + require.Equal(t, test.timeNow.UnixMilli(), result.Time.UnixMilli()) + require.Equal(t, test.workerID, result.WorkerID) + require.Equal(t, test.databaseID, result.DatabaseID) + require.Equal(t, test.tableRotate, result.TableRotate) + require.Equal(t, test.sequence, result.Sequence) + + switch test.tableRotate { + case None: + require.Equal(t, "", result.RotateName()) + case Monthly: + require.Equal(t, test.timeNow.UTC().Format("200601"), result.RotateName()) + case Weekly: + _, week := test.timeNow.UTC().ISOWeek() + require.Equal(t, test.timeNow.UTC().Format("2006")+fmt.Sprintf("%03d", week), result.RotateName()) + case Daily: + require.Equal(t, test.timeNow.UTC().Format("20060102"), result.RotateName()) + default: + require.Equal(t, "", result.RotateName()) + } + + if test.orderby { + id2 := Build(test.timeNow.UnixMilli(), test.workerID, test.databaseID, test.tableRotate, test.sequence+1) + + id3 := Build(test.timeNow.UnixMilli(), test.workerID+1, test.databaseID, test.tableRotate, test.sequence+2) + + id4 := Build(test.timeNow.Add(1*time.Millisecond).UnixMilli(), test.workerID, test.databaseID, test.tableRotate, test.sequence+3) + + require.Greater(t, id2, id) + require.Greater(t, id3, id2) + require.Greater(t, id4, id3) + } + + }) + } +} diff --git a/sharding/nocopy.go b/sharding/nocopy.go new file mode 100644 index 0000000..be9a08a --- /dev/null +++ b/sharding/nocopy.go @@ -0,0 +1,10 @@ +package sharding + +// nolint: unused +type noCopy struct{} + +// nolint: unused +func (*noCopy) Lock() {} + +// nolint: unused +func (*noCopy) Unlock() {} diff --git a/sharding/option.go b/sharding/option.go new file mode 100644 index 0000000..dcd52f8 --- /dev/null +++ b/sharding/option.go @@ -0,0 +1,37 @@ +package sharding + +import ( + "time" +) + +type Option func(g *Generator) + +func WithWorkerID(i int8) Option { + return func(g *Generator) { + if i >= 0 && i <= MaxWorkerID { + g.workerID = i + } + } +} + +func WithDatabase(total int16) Option { + return func(g *Generator) { + if total >= 0 && total <= MaxDatabaseID { + g.databaseTotal = total + } + } +} + +func WithTableRotate(ts TableRotate) Option { + return func(g *Generator) { + if ts >= None && ts <= Daily { + g.tableRotate = ts + } + } +} + +func WithTimeNow(now func() time.Time) Option { + return func(g *Generator) { + g.now = now + } +} diff --git a/sharding/woker.go b/sharding/woker.go new file mode 100644 index 0000000..c1da974 --- /dev/null +++ b/sharding/woker.go @@ -0,0 +1,20 @@ +package sharding + +import ( + "os" + "strconv" +) + +func acquireWorkerID() int8 { + + i, err := strconv.Atoi(os.Getenv("SQLE_WORKER_ID")) + if err != nil { + return 0 + } + + if i >= 0 && i <= int(MaxWorkerID) { + return int8(i) + } + + return 0 +} diff --git a/sqlbuilder.go b/sqlbuilder.go index a7a220c..1f79fcd 100644 --- a/sqlbuilder.go +++ b/sqlbuilder.go @@ -27,6 +27,7 @@ func New(cmd ...string) *Builder { params: make(map[string]any), } + // MySQL as default UseMySQL(b) for i, it := range cmd { @@ -173,7 +174,7 @@ func (b *Builder) Delete(table string) *Builder { func UsePostgres(b *Builder) { b.Quote = "`" b.Parameterize = func(name string, index int) string { - return "?" + strconv.Itoa(index) + return "$" + strconv.Itoa(index) } } diff --git a/sqlbuilder_test.go b/sqlbuilder_test.go index ef18306..e7d6924 100644 --- a/sqlbuilder_test.go +++ b/sqlbuilder_test.go @@ -34,7 +34,7 @@ func TestBuilder(t *testing.T) { { name: "build_with_input_tokens", build: func() *Builder { - b := New("SELECT * FROM orders_ as orders") + b := New("SELECT * FROM", "orders_ as orders") b.SQL(" WHERE orders.created>=now()") b.Input("yyyyMM", "202401") return b @@ -71,8 +71,11 @@ func TestBuilder(t *testing.T) { b := New("SELECT * FROM orders_ as orders LEFT JOIN users_") b.SQL(" ON users_.id=orders.user_id") b.SQL(" WHERE users_.id={user_id} and orders.user_id={user_id} and orders.status={order_status} and orders.created>={now}") - b.Input("dbid", "db2") - b.Input("yyyy", "2024") + b.Inputs(map[string]string{ + "dbid": "db2", + "yyyy": "2024", + }) + b.Param("order_status", 1) b.Param("now", now) b.Param("user_id", "u123456") diff --git a/sqlbuilder_where.go b/sqlbuilder_where.go index eaeabc4..ee02aa3 100644 --- a/sqlbuilder_where.go +++ b/sqlbuilder_where.go @@ -2,15 +2,12 @@ package sqle type WhereBuilder struct { *Builder - written bool - skipTimes int + written bool + shouldSkip bool } func (wb *WhereBuilder) If(predicate bool) *WhereBuilder { - if !predicate { - wb.skipTimes++ - } - + wb.shouldSkip = !predicate return wb } @@ -23,13 +20,12 @@ func (wb *WhereBuilder) Or(cmd string) *WhereBuilder { } func (wb *WhereBuilder) SQL(op string, cmd string) *WhereBuilder { - if cmd == "" { return wb } - if wb.skipTimes > 0 { - wb.skipTimes-- + if wb.shouldSkip { + wb.shouldSkip = false return wb }