diff --git a/CHANGELOG.md b/CHANGELOG.md index 8031c00..b85f01c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,13 +5,17 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [1.4.7] +## [1.5.0] +### Changed +- !renamed `Context` with `Client` (#45) + ### Added - added `Connector` interface (#43) - added nullable `Time` with better json support (#44) +- added `DTC` (#45) ### Changed -- renamed `BitBool` with shorter name `Bool` (#44) +- !renamed `BitBool` with shorter name `Bool` (#44) ## [1.4.6] - 2014-04-23 ### Changed diff --git a/context.go b/client.go similarity index 70% rename from context.go rename to client.go index 6ffe80d..a7d4c27 100644 --- a/context.go +++ b/client.go @@ -8,7 +8,7 @@ import ( "time" ) -type Context struct { +type Client struct { *sql.DB sync.Mutex _ noCopy @@ -20,11 +20,11 @@ type Context struct { Index int } -func (db *Context) Query(query string, args ...any) (*Rows, error) { +func (db *Client) Query(query string, args ...any) (*Rows, error) { return db.QueryContext(context.Background(), query, args...) } -func (db *Context) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error) { +func (db *Client) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error) { query, args, err := b.Build() if err != nil { return nil, err @@ -33,7 +33,7 @@ func (db *Context) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error) return db.QueryContext(ctx, query, args...) } -func (db *Context) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { +func (db *Client) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { var rows *sql.Rows var stmt *Stmt var err error @@ -59,11 +59,11 @@ func (db *Context) QueryContext(ctx context.Context, query string, args ...any) return &Rows{Rows: rows, stmt: stmt, query: query}, nil } -func (db *Context) QueryRow(query string, args ...any) *Row { +func (db *Client) QueryRow(query string, args ...any) *Row { return db.QueryRowContext(context.Background(), query, args...) } -func (db *Context) QueryRowBuilder(ctx context.Context, b *Builder) *Row { +func (db *Client) QueryRowBuilder(ctx context.Context, b *Builder) *Row { query, args, err := b.Build() if err != nil { return &Row{ @@ -75,7 +75,7 @@ func (db *Context) QueryRowBuilder(ctx context.Context, b *Builder) *Row { return db.QueryRowContext(ctx, query, args...) } -func (db *Context) QueryRowContext(ctx context.Context, query string, args ...any) *Row { +func (db *Client) QueryRowContext(ctx context.Context, query string, args ...any) *Row { var rows *sql.Rows var stmt *Stmt var err error @@ -108,11 +108,11 @@ func (db *Context) QueryRowContext(ctx context.Context, query string, args ...an } } -func (db *Context) Exec(query string, args ...any) (sql.Result, error) { +func (db *Client) Exec(query string, args ...any) (sql.Result, error) { return db.ExecContext(context.Background(), query, args...) } -func (db *Context) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, error) { +func (db *Client) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, error) { query, args, err := b.Build() if err != nil { return nil, err @@ -121,7 +121,7 @@ func (db *Context) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, err return db.ExecContext(ctx, query, args...) } -func (db *Context) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { +func (db *Client) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { if len(args) > 0 { stmt, err := db.prepareStmt(ctx, query) if err != nil { @@ -135,12 +135,12 @@ func (db *Context) ExecContext(ctx context.Context, query string, args ...any) ( return db.DB.ExecContext(context.Background(), query, args...) } -func (db *Context) Begin(opts *sql.TxOptions) (*Tx, error) { +func (db *Client) Begin(opts *sql.TxOptions) (*Tx, error) { return db.BeginTx(context.TODO(), opts) } -func (db *Context) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { +func (db *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { tx, err := db.DB.BeginTx(ctx, opts) if err != nil { return nil, err @@ -149,7 +149,7 @@ func (db *Context) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error return &Tx{Tx: tx, stmts: make(map[string]*sql.Stmt)}, nil } -func (db *Context) Transaction(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx *Tx) error) error { +func (db *Client) Transaction(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx *Tx) error) error { tx, err := db.BeginTx(ctx, opts) if err != nil { return err diff --git a/context_stmt.go b/client_stmt.go similarity index 86% rename from context_stmt.go rename to client_stmt.go index 3c94352..fa31618 100644 --- a/context_stmt.go +++ b/client_stmt.go @@ -21,7 +21,7 @@ func (s *Stmt) Reuse() { s.isUsing = false } -func (db *Context) prepareStmt(ctx context.Context, query string) (*Stmt, error) { +func (db *Client) prepareStmt(ctx context.Context, query string) (*Stmt, error) { db.stmtsMutex.Lock() defer db.stmtsMutex.Unlock() s, ok := db.stmts[query] @@ -48,7 +48,7 @@ func (db *Context) prepareStmt(ctx context.Context, query string) (*Stmt, error) return s, nil } -func (db *Context) closeStaleStmt() { +func (db *Client) closeStaleStmt() { db.stmtsMutex.Lock() defer db.stmtsMutex.Unlock() @@ -64,7 +64,7 @@ func (db *Context) closeStaleStmt() { } -func (db *Context) checkIdleStmt() { +func (db *Client) checkIdleStmt() { delay := time.NewTicker(db.stmtMaxIdleTime) defer delay.Stop() diff --git a/context_stmt_test.go b/client_stmt_test.go similarity index 100% rename from context_stmt_test.go rename to client_stmt_test.go diff --git a/connector.go b/connector.go index cc6107c..514e683 100644 --- a/connector.go +++ b/connector.go @@ -6,7 +6,7 @@ import ( ) // Connector represents a database connector that provides methods for executing queries and commands. -// Context and Tx both implement this interface. +// Conn and Tx both implement this interface. type Connector interface { // Query executes a query that returns multiple rows. // It takes a query string and optional arguments. diff --git a/db.go b/db.go index 13d8208..b201fac 100644 --- a/db.go +++ b/db.go @@ -16,12 +16,12 @@ var ( // DB represents a database connection pool with sharding support. type DB struct { - *Context + *Client _ noCopy //nolint: unused mu sync.RWMutex dhts map[string]*shardid.DHT - dbs []*Context + dbs []*Client } // Open creates a new DB instance with the provided database connections. @@ -31,7 +31,7 @@ func Open(dbs ...*sql.DB) *DB { } for i, db := range dbs { - ctx := &Context{ + ctx := &Client{ DB: db, Index: i, stmts: make(map[string]*Stmt), @@ -41,7 +41,7 @@ func Open(dbs ...*sql.DB) *DB { go ctx.checkIdleStmt() } - d.Context = d.dbs[0] + d.Client = d.dbs[0] return d } @@ -54,7 +54,7 @@ func (db *DB) Add(dbs ...*sql.DB) { n := len(db.dbs) for i, d := range dbs { - ctx := &Context{ + ctx := &Client{ DB: d, Index: n + i, stmts: make(map[string]*Stmt), @@ -66,7 +66,7 @@ func (db *DB) Add(dbs ...*sql.DB) { } // On selects the database context based on the shardid ID. -func (db *DB) On(id shardid.ID) *Context { +func (db *DB) On(id shardid.ID) *Client { db.mu.RLock() defer db.mu.RUnlock() @@ -90,7 +90,7 @@ func (db *DB) GetDHT(name string) *shardid.DHT { } // OnDHT selects the database context based on the DHT (Distributed Hash Table) key. -func (db *DB) OnDHT(key string, names ...string) (*Context, error) { +func (db *DB) OnDHT(key string, names ...string) (*Client, error) { db.mu.RLock() defer db.mu.RUnlock() diff --git a/dtc.go b/dtc.go new file mode 100644 index 0000000..acd0658 --- /dev/null +++ b/dtc.go @@ -0,0 +1,108 @@ +package sqle + +import ( + "context" + "database/sql" +) + +// DTC Distributed Transaction Coordinator +type DTC struct { + ctx context.Context + opts *sql.TxOptions + + sessions []*session +} + +// session represents a transaction session. +type session struct { + committed bool + client *Client + tx *Tx + exec []func(context.Context, Connector) error + revert []func(context.Context, Connector) error +} + +// NewDTC creates a new instance of DTC. +func NewDTC(ctx context.Context, opts *sql.TxOptions) *DTC { + return &DTC{ + ctx: ctx, + opts: opts, + } +} + +// Prepare adds a new transaction session to the DTC. +func (d *DTC) Prepare(client *Client, exec func(ctx context.Context, conn Connector) error, revert func(ctx context.Context, conn Connector) error) { + for _, s := range d.sessions { + if s.client == client { + s.exec = append(s.exec, exec) + s.revert = append(s.revert, revert) + return + } + } + + s := &session{ + committed: false, + client: client, + exec: []func(ctx context.Context, c Connector) error{ + exec, + }, + revert: []func(ctx context.Context, c Connector) error{ + revert, + }, + } + + d.sessions = append(d.sessions, s) + +} + +// Commit commits all the prepared transactions in the DTC. +func (d *DTC) Commit() error { + for _, s := range d.sessions { + tx, err := s.client.BeginTx(d.ctx, d.opts) + if err != nil { + return err + } + + s.tx = tx + + for _, exec := range s.exec { + err = exec(d.ctx, tx) + if err != nil { + return err + } + } + } + + for _, s := range d.sessions { + err := s.tx.Commit() + if err != nil { + return err + } + + s.committed = true + } + + return nil +} + +// Rollback rolls back all the prepared transactions in the DTC. +func (d *DTC) Rollback() []error { + var errs []error + + for _, s := range d.sessions { + if s.committed { + for _, revert := range s.revert { + if err := revert(d.ctx, s.client); err != nil { + errs = append(errs, err) + } + } + + } else { + if err := s.tx.Rollback(); err != nil { + errs = append(errs, err) + } + } + } + + return errs +} diff --git a/dtc_test.go b/dtc_test.go new file mode 100644 index 0000000..2459c1c --- /dev/null +++ b/dtc_test.go @@ -0,0 +1,316 @@ +package sqle + +import ( + "context" + "database/sql" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDTCWithDB(t *testing.T) { + os.Remove("dtc_db.db") + + d, err := sql.Open("sqlite3", "file:dtc_1.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d.Exec("CREATE TABLE `dtc_1` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + db := Open(d) + + var tests = []struct { + name string + setup func() *DTC + assert func(ra *require.Assertions) + }{ + { + name: "multiple_txs_commit_should_work", + setup: func() *DTC { + dtc := NewDTC(context.Background(), nil) + + dtc.Prepare(db.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 1, "1@mail.com") + + return err + }, nil) + + dtc.Prepare(db.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 2, "2@mail.com") + + return err + }, nil) + + dtc.Prepare(db.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 3, "3@mail.com") + + return err + }, nil) + + return dtc + }, + assert: func(ra *require.Assertions) { + var id int + err := db.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 1).Scan(&id) + ra.NoError(err) + ra.Equal(1, id) + + err = db.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 2).Scan(&id) + ra.NoError(err) + ra.Equal(2, id) + + err = db.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 3).Scan(&id) + ra.NoError(err) + ra.Equal(3, id) + }, + }, + { + name: "multiple_txs_rollback_should_work", + setup: func() *DTC { + dtc := NewDTC(context.Background(), nil) + dtc.Prepare(db.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 11, "1@mail.com") + + return err + }, nil) + + dtc.Prepare(db.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 12, "2@mail.com") + + return err + }, nil) + + dtc.Prepare(db.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 13) + + return err + }, nil) + + return dtc + }, + assert: func(ra *require.Assertions) { + var id int + err := db.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 11).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + + err = db.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 12).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + + err = db.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 13).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dtc := test.setup() + + err := dtc.Commit() + if err != nil { + dtc.Rollback() + } + + test.assert(require.New(t)) + }) + } +} + +func TestDTCWithDBs(t *testing.T) { + os.Remove("dtc_dbs_1.db") + os.Remove("dtc_dbs_2.db") + os.Remove("dtc_dbs_3.db") + + d1, err := sql.Open("sqlite3", "file:dtc_dbs_1.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d1.Exec("CREATE TABLE `dtc_1` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + d2, err := sql.Open("sqlite3", "file:dtc_dbs_2.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d2.Exec("CREATE TABLE `dtc_2` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + d3, err := sql.Open("sqlite3", "file:dtc_dbs_3.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d3.Exec("CREATE TABLE `dtc_3` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + db1 := Open(d1) + db2 := Open(d2) + db3 := Open(d3) + + var tests = []struct { + name string + setup func() *DTC + assert func(ra *require.Assertions) + }{ + { + name: "multiple_txs_commit_should_work", + setup: func() *DTC { + dtc := NewDTC(context.Background(), nil) + + dtc.Prepare(db1.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1` (`id`,`email`) VALUES(?,?)", 1, "1@mail.com") + + return err + }, nil) + + dtc.Prepare(db2.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_2` (`id`,`email`) VALUES(?,?)", 2, "2@mail.com") + + return err + }, nil) + + dtc.Prepare(db3.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_3` (`id`,`email`) VALUES(?,?)", 3, "3@mail.com") + + return err + }, nil) + + return dtc + }, + assert: func(ra *require.Assertions) { + var id int + err := db1.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 1).Scan(&id) + ra.NoError(err) + ra.Equal(1, id) + + err = db2.QueryRow("SELECT id FROM `dtc_2` WHERE id=?", 2).Scan(&id) + ra.NoError(err) + ra.Equal(2, id) + + err = db3.QueryRow("SELECT id FROM `dtc_3` WHERE id=?", 3).Scan(&id) + ra.NoError(err) + ra.Equal(3, id) + }, + }, + { + name: "multiple_txs_rollback_should_work", + setup: func() *DTC { + dtc := NewDTC(context.Background(), nil) + dtc.Prepare(db1.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 11, "1@mail.com") + + return err + }, nil) + + dtc.Prepare(db2.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_2`(`id`,`email`) VALUES(?,?)", 12, "2@mail.com") + + return err + }, nil) + + dtc.Prepare(db3.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_3`(`id`,`email`) VALUES(?,?)", 13) + + return err + }, nil) + + return dtc + }, + assert: func(ra *require.Assertions) { + var id int + err := db1.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 11).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + + err = db2.QueryRow("SELECT id FROM `dtc_2` WHERE id=?", 12).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + + err = db3.QueryRow("SELECT id FROM `dtc_3` WHERE id=?", 13).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dtc := test.setup() + + err := dtc.Commit() + if err != nil { + dtc.Rollback() + } + + test.assert(require.New(t)) + }) + } +} + +func TestDTCRevert(t *testing.T) { + os.Remove("dtc_revert_1.db") + os.Remove("dtc_revert_2.db") + os.Remove("dtc_revert_3.db") + + d1, err := sql.Open("sqlite3", "file:dtc_revert_1.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d1.Exec("CREATE TABLE `dtc_1` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + d2, err := sql.Open("sqlite3", "file:dtc_revert_2.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d2.Exec("CREATE TABLE `dtc_2` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + d3, err := sql.Open("sqlite3", "file:dtc_revert_3.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d3.Exec("CREATE TABLE `dtc_3` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + db1 := Open(d1) + db2 := Open(d2) + db3 := Open(d3) + + dtc := NewDTC(context.Background(), nil) + + dtc.Prepare(db1.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1` (`id`,`email`) VALUES(?,?)", 1, "1@mail.com") + return err + }, func(ctx context.Context, c Connector) error { + _, err := c.Exec("DELETE FROM `dtc_1` WHERE id=?", 1) + return err + }) + + dtc.Prepare(db2.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_2` (`id`,`email`) VALUES(?,?)", 2, "2@mail.com") + + return err + }, func(ctx context.Context, c Connector) error { + _, err := c.Exec("DELETE FROM `dtc_2` WHERE id=?", 2) + return err + }) + + dtc.Prepare(db3.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_3` (`id`,`email`) VALUES(?,?)", 3, "3@mail.com") + + return err + }, func(ctx context.Context, c Connector) error { + _, err := c.Exec("DELETE FROM `dtc_3` WHERE id=?", 3) + return err + }) + + ra := require.New(t) + err = dtc.Commit() + ra.NoError(err) + + errs := dtc.Rollback() + ra.Len(errs, 0) + + var id int + err = db1.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 11).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + + err = db2.QueryRow("SELECT id FROM `dtc_2` WHERE id=?", 12).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + + err = db3.QueryRow("SELECT id FROM `dtc_3` WHERE id=?", 13).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + +} diff --git a/queryer_mapr.go b/queryer_mapr.go index b1e3a8c..bf6c00d 100644 --- a/queryer_mapr.go +++ b/queryer_mapr.go @@ -11,7 +11,7 @@ import ( // MapR is a Map/Reduce Query Provider based on databases. type MapR[T any] struct { - dbs []*Context + dbs []*Client } // First executes the query and returns the first result. @@ -28,7 +28,7 @@ func (q *MapR[T]) First(ctx context.Context, rotatedTables []string, b *Builder) for _, r := range rotatedTables { qr := strings.ReplaceAll(query, "", r) for _, db := range q.dbs { - w.Add(func(db *Context, qr string) func(context.Context) (T, error) { + w.Add(func(db *Client, qr string) func(context.Context) (T, error) { return func(ctx context.Context) (T, error) { var t T err := db.QueryRowContext(ctx, qr, args...).Bind(&t) @@ -59,7 +59,7 @@ func (q *MapR[T]) Count(ctx context.Context, rotatedTables []string, b *Builder) for _, r := range rotatedTables { qr := strings.ReplaceAll(query, "", r) for _, db := range q.dbs { - w.Add(func(db *Context, qr string) func(context.Context) (int64, error) { + w.Add(func(db *Client, qr string) func(context.Context) (int64, error) { return func(ctx context.Context) (int64, error) { var i int64 err := db.QueryRowContext(ctx, qr, args...).Scan(&i) @@ -102,7 +102,7 @@ func (q *MapR[T]) Query(ctx context.Context, rotatedTables []string, b *Builder, for _, r := range rotatedTables { qr := strings.ReplaceAll(query, "", r) for _, db := range q.dbs { - w.Add(func(db *Context, qr string) func(context.Context) ([]T, error) { + w.Add(func(db *Client, qr string) func(context.Context) ([]T, error) { return func(context.Context) ([]T, error) { var t []T rows, err := db.QueryContext(ctx, qr, args...)