Skip to content

Commit

Permalink
!feat(dtc):renamed Context and added DTC (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi authored Apr 29, 2024
1 parent a6accdf commit f7859c0
Show file tree
Hide file tree
Showing 9 changed files with 458 additions and 30 deletions.
8 changes: 6 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.4.7]
## [1.5.0]
### Changed
- !renamed `Context` with `Client` (#45)

### Added
- added `Connector` interface (#43)
- added nullable `Time` with better json support (#44)
- added `DTC` (#45)

### Changed
- renamed `BitBool` with shorter name `Bool` (#44)
- !renamed `BitBool` with shorter name `Bool` (#44)

## [1.4.6] - 2014-04-23
### Changed
Expand Down
26 changes: 13 additions & 13 deletions context.go → client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"time"
)

type Context struct {
type Client struct {
*sql.DB
sync.Mutex
_ noCopy
Expand All @@ -20,11 +20,11 @@ type Context struct {
Index int
}

func (db *Context) Query(query string, args ...any) (*Rows, error) {
func (db *Client) Query(query string, args ...any) (*Rows, error) {
return db.QueryContext(context.Background(), query, args...)
}

func (db *Context) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error) {
func (db *Client) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error) {
query, args, err := b.Build()
if err != nil {
return nil, err
Expand All @@ -33,7 +33,7 @@ func (db *Context) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error)
return db.QueryContext(ctx, query, args...)
}

func (db *Context) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
func (db *Client) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
var rows *sql.Rows
var stmt *Stmt
var err error
Expand All @@ -59,11 +59,11 @@ func (db *Context) QueryContext(ctx context.Context, query string, args ...any)
return &Rows{Rows: rows, stmt: stmt, query: query}, nil
}

func (db *Context) QueryRow(query string, args ...any) *Row {
func (db *Client) QueryRow(query string, args ...any) *Row {
return db.QueryRowContext(context.Background(), query, args...)
}

func (db *Context) QueryRowBuilder(ctx context.Context, b *Builder) *Row {
func (db *Client) QueryRowBuilder(ctx context.Context, b *Builder) *Row {
query, args, err := b.Build()
if err != nil {
return &Row{
Expand All @@ -75,7 +75,7 @@ func (db *Context) QueryRowBuilder(ctx context.Context, b *Builder) *Row {
return db.QueryRowContext(ctx, query, args...)
}

func (db *Context) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
func (db *Client) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
var rows *sql.Rows
var stmt *Stmt
var err error
Expand Down Expand Up @@ -108,11 +108,11 @@ func (db *Context) QueryRowContext(ctx context.Context, query string, args ...an
}
}

func (db *Context) Exec(query string, args ...any) (sql.Result, error) {
func (db *Client) Exec(query string, args ...any) (sql.Result, error) {
return db.ExecContext(context.Background(), query, args...)
}

func (db *Context) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, error) {
func (db *Client) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, error) {
query, args, err := b.Build()
if err != nil {
return nil, err
Expand All @@ -121,7 +121,7 @@ func (db *Context) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, err
return db.ExecContext(ctx, query, args...)
}

func (db *Context) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
func (db *Client) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
if len(args) > 0 {
stmt, err := db.prepareStmt(ctx, query)
if err != nil {
Expand All @@ -135,12 +135,12 @@ func (db *Context) ExecContext(ctx context.Context, query string, args ...any) (
return db.DB.ExecContext(context.Background(), query, args...)
}

func (db *Context) Begin(opts *sql.TxOptions) (*Tx, error) {
func (db *Client) Begin(opts *sql.TxOptions) (*Tx, error) {
return db.BeginTx(context.TODO(), opts)

}

func (db *Context) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
func (db *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
tx, err := db.DB.BeginTx(ctx, opts)
if err != nil {
return nil, err
Expand All @@ -149,7 +149,7 @@ func (db *Context) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error
return &Tx{Tx: tx, stmts: make(map[string]*sql.Stmt)}, nil
}

func (db *Context) Transaction(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx *Tx) error) error {
func (db *Client) Transaction(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx *Tx) error) error {
tx, err := db.BeginTx(ctx, opts)
if err != nil {
return err
Expand Down
6 changes: 3 additions & 3 deletions context_stmt.go → client_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (s *Stmt) Reuse() {
s.isUsing = false
}

func (db *Context) prepareStmt(ctx context.Context, query string) (*Stmt, error) {
func (db *Client) prepareStmt(ctx context.Context, query string) (*Stmt, error) {
db.stmtsMutex.Lock()
defer db.stmtsMutex.Unlock()
s, ok := db.stmts[query]
Expand All @@ -48,7 +48,7 @@ func (db *Context) prepareStmt(ctx context.Context, query string) (*Stmt, error)
return s, nil
}

func (db *Context) closeStaleStmt() {
func (db *Client) closeStaleStmt() {
db.stmtsMutex.Lock()
defer db.stmtsMutex.Unlock()

Expand All @@ -64,7 +64,7 @@ func (db *Context) closeStaleStmt() {

}

func (db *Context) checkIdleStmt() {
func (db *Client) checkIdleStmt() {
delay := time.NewTicker(db.stmtMaxIdleTime)
defer delay.Stop()

Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

// Connector represents a database connector that provides methods for executing queries and commands.
// Context and Tx both implement this interface.
// Conn and Tx both implement this interface.
type Connector interface {
// Query executes a query that returns multiple rows.
// It takes a query string and optional arguments.
Expand Down
14 changes: 7 additions & 7 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ var (

// DB represents a database connection pool with sharding support.
type DB struct {
*Context
*Client
_ noCopy //nolint: unused

mu sync.RWMutex
dhts map[string]*shardid.DHT
dbs []*Context
dbs []*Client
}

// Open creates a new DB instance with the provided database connections.
Expand All @@ -31,7 +31,7 @@ func Open(dbs ...*sql.DB) *DB {
}

for i, db := range dbs {
ctx := &Context{
ctx := &Client{
DB: db,
Index: i,
stmts: make(map[string]*Stmt),
Expand All @@ -41,7 +41,7 @@ func Open(dbs ...*sql.DB) *DB {
go ctx.checkIdleStmt()
}

d.Context = d.dbs[0]
d.Client = d.dbs[0]

return d
}
Expand All @@ -54,7 +54,7 @@ func (db *DB) Add(dbs ...*sql.DB) {
n := len(db.dbs)

for i, d := range dbs {
ctx := &Context{
ctx := &Client{
DB: d,
Index: n + i,
stmts: make(map[string]*Stmt),
Expand All @@ -66,7 +66,7 @@ func (db *DB) Add(dbs ...*sql.DB) {
}

// On selects the database context based on the shardid ID.
func (db *DB) On(id shardid.ID) *Context {
func (db *DB) On(id shardid.ID) *Client {
db.mu.RLock()
defer db.mu.RUnlock()

Expand All @@ -90,7 +90,7 @@ func (db *DB) GetDHT(name string) *shardid.DHT {
}

// OnDHT selects the database context based on the DHT (Distributed Hash Table) key.
func (db *DB) OnDHT(key string, names ...string) (*Context, error) {
func (db *DB) OnDHT(key string, names ...string) (*Client, error) {
db.mu.RLock()
defer db.mu.RUnlock()

Expand Down
108 changes: 108 additions & 0 deletions dtc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package sqle

import (
"context"
"database/sql"
)

// DTC Distributed Transaction Coordinator
type DTC struct {
ctx context.Context
opts *sql.TxOptions

sessions []*session
}

// session represents a transaction session.
type session struct {
committed bool
client *Client
tx *Tx
exec []func(context.Context, Connector) error
revert []func(context.Context, Connector) error
}

// NewDTC creates a new instance of DTC.
func NewDTC(ctx context.Context, opts *sql.TxOptions) *DTC {
return &DTC{
ctx: ctx,
opts: opts,
}
}

// Prepare adds a new transaction session to the DTC.
func (d *DTC) Prepare(client *Client, exec func(ctx context.Context, conn Connector) error, revert func(ctx context.Context, conn Connector) error) {
for _, s := range d.sessions {
if s.client == client {
s.exec = append(s.exec, exec)
s.revert = append(s.revert, revert)
return
}
}

s := &session{
committed: false,
client: client,
exec: []func(ctx context.Context, c Connector) error{
exec,
},
revert: []func(ctx context.Context, c Connector) error{
revert,
},
}

d.sessions = append(d.sessions, s)

}

// Commit commits all the prepared transactions in the DTC.
func (d *DTC) Commit() error {
for _, s := range d.sessions {
tx, err := s.client.BeginTx(d.ctx, d.opts)
if err != nil {
return err
}

s.tx = tx

for _, exec := range s.exec {
err = exec(d.ctx, tx)
if err != nil {
return err
}
}
}

for _, s := range d.sessions {
err := s.tx.Commit()
if err != nil {
return err
}

s.committed = true
}

return nil
}

// Rollback rolls back all the prepared transactions in the DTC.
func (d *DTC) Rollback() []error {
var errs []error

for _, s := range d.sessions {
if s.committed {
for _, revert := range s.revert {
if err := revert(d.ctx, s.client); err != nil {
errs = append(errs, err)
}
}

} else {
if err := s.tx.Rollback(); err != nil {
errs = append(errs, err)
}
}
}

return errs
}
Loading

0 comments on commit f7859c0

Please sign in to comment.