Skip to content

Commit

Permalink
Merge pull request #6 from masudur-rahman/txn
Browse files Browse the repository at this point in the history
Introduce transaction for postgres, sqlite database
  • Loading branch information
masudur-rahman authored Apr 3, 2024
2 parents 7eace7d + 248372e commit 702c1b2
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 33 deletions.
4 changes: 4 additions & 0 deletions sql/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
)

type Database interface {
BeginTx() (Database, error)
Commit() error
Rollback() error

Table(name string) Database

ID(id any) Database
Expand Down
35 changes: 26 additions & 9 deletions sql/postgres/lib/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,23 @@ func (stmt Statement) GenerateReadQuery() string {
return query
}

func (stmt Statement) ExecuteReadQuery(ctx context.Context, conn *sql.Conn, query string, doc any) error {
func (stmt Statement) ExecuteReadQuery(ctx context.Context, conn *sql.Conn, tx *sql.Tx, query string, doc any) error {
//defer stmt.cleanup()

if stmt.showSQL {
log.Printf("Read Query: query: %v, args: %v\n", query, stmt.args)
}

rows, err := conn.QueryContext(ctx, query, stmt.args...)
var (
err error
rows *sql.Rows
)

if tx != nil {
rows, err = tx.QueryContext(ctx, query, stmt.args...)
} else {
rows, err = conn.QueryContext(ctx, query, stmt.args...)
}
if err != nil {
return err
}
Expand Down Expand Up @@ -183,25 +192,33 @@ func (stmt Statement) GenerateInsertQuery(doc any) string {
return query
}

func (stmt Statement) ExecuteInsertQuery(ctx context.Context, conn *sql.Conn, query string) (any, error) {
func (stmt Statement) ExecuteInsertQuery(ctx context.Context, conn *sql.Conn, tx *sql.Tx, query string) (any, error) {
query += " RETURNING id;"
if stmt.showSQL {
log.Printf("Insert Query: query: %v, args: %v\n", query, stmt.args)
}

var id any
err := conn.QueryRowContext(ctx, query, stmt.args...).Scan(&id)
var (
id any
err error
)
if tx != nil {
err = tx.QueryRowContext(ctx, query, stmt.args...).Scan(&id)
} else {
err = conn.QueryRowContext(ctx, query, stmt.args...).Scan(&id)
}
return id, err
}

func (stmt Statement) ExecuteWriteQuery(ctx context.Context, conn *sql.Conn, query string) (sql.Result, error) {
func (stmt Statement) ExecuteWriteQuery(ctx context.Context, conn *sql.Conn, tx *sql.Tx, query string) (sql.Result, error) {
if stmt.showSQL {
log.Printf("Write Query: query: %v, args: %v\n", query, stmt.args)
}

result, err := conn.ExecContext(ctx, query, stmt.args...)

return result, err
if tx != nil {
return tx.ExecContext(ctx, query, stmt.args...)
}
return conn.ExecContext(ctx, query, stmt.args...)
}

func (stmt Statement) generateMustColMap() map[string]bool {
Expand Down
50 changes: 42 additions & 8 deletions sql/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package postgres
import (
"context"
"database/sql"
"errors"

isql "github.com/masudur-rahman/database/sql"
"github.com/masudur-rahman/database/sql/postgres/lib"
Expand All @@ -11,13 +12,46 @@ import (
type Postgres struct {
ctx context.Context
conn *sql.Conn
tx *sql.Tx
statement lib.Statement
}

func NewPostgres(ctx context.Context, conn *sql.Conn) Postgres {
return Postgres{ctx: ctx, conn: conn}
}

var _ isql.Database = Postgres{}

func (pg Postgres) BeginTx() (isql.Database, error) {
if pg.tx != nil {
return nil, errors.New("session already in progress")
}
tx, err := pg.conn.BeginTx(pg.ctx, nil)
if err != nil {
return nil, err
}
pg.tx = tx
return pg, nil
}

func (pg Postgres) Commit() error {
if pg.tx == nil {
return errors.New("no transaction in progress")
}
err := pg.tx.Commit()
pg.tx = nil
return err
}

func (pg Postgres) Rollback() error {
if pg.tx == nil {
return errors.New("no transaction in progress")
}
err := pg.tx.Rollback()
pg.tx = nil
return err
}

func (pg Postgres) Table(name string) isql.Database {
pg.statement = pg.statement.Table(name)
return pg
Expand Down Expand Up @@ -66,7 +100,7 @@ func (pg Postgres) FindOne(document any, filter ...any) (bool, error) {
}

query := pg.statement.GenerateReadQuery()
err := pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, query, document)
err := pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, pg.tx, query, document)
if err == nil {
return true, nil
}
Expand All @@ -81,19 +115,19 @@ func (pg Postgres) FindMany(documents any, filter ...any) error {
pg.statement = pg.statement.GenerateWhereClause(filter...)

query := pg.statement.GenerateReadQuery()
return pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, query, documents)
return pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, pg.tx, query, documents)
}

func (pg Postgres) InsertOne(document any) (id any, err error) {
query := pg.statement.GenerateInsertQuery(document)
return pg.statement.ExecuteInsertQuery(pg.ctx, pg.conn, query)
return pg.statement.ExecuteInsertQuery(pg.ctx, pg.conn, pg.tx, query)
}

func (pg Postgres) InsertMany(documents []any) ([]any, error) {
var ids []any
for _, doc := range documents {
query := pg.statement.GenerateInsertQuery(doc)
id, err := pg.statement.ExecuteInsertQuery(pg.ctx, pg.conn, query)
id, err := pg.statement.ExecuteInsertQuery(pg.ctx, pg.conn, pg.tx, query)
if err != nil {
return nil, err
}
Expand All @@ -110,7 +144,7 @@ func (pg Postgres) UpdateOne(document any) error {
}

query := pg.statement.GenerateUpdateQuery(document)
_, err := pg.statement.ExecuteWriteQuery(pg.ctx, pg.conn, query)
_, err := pg.statement.ExecuteWriteQuery(pg.ctx, pg.conn, pg.tx, query)
return err
}

Expand All @@ -121,7 +155,7 @@ func (pg Postgres) DeleteOne(filter ...any) error {
}

query := pg.statement.GenerateDeleteQuery()
_, err := pg.statement.ExecuteWriteQuery(pg.ctx, pg.conn, query)
_, err := pg.statement.ExecuteWriteQuery(pg.ctx, pg.conn, pg.tx, query)
return err
}

Expand All @@ -144,8 +178,8 @@ func (pg Postgres) Sync(tables ...any) error {
return nil
}

func (p Postgres) Close() error {
return p.conn.Close()
func (pg Postgres) Close() error {
return pg.conn.Close()
}

func (pg Postgres) cleanup() {
Expand Down
17 changes: 17 additions & 0 deletions sql/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ func TestPostgres_FindOne(t *testing.T) {
db, closer := initializeDB(t)
defer closer()

db, err := db.BeginTx()
assert.Nil(t, err)
defer func() {
err = db.Commit()
assert.Nil(t, err)
}()

user := TestUser{}
db = db.Table("test_user")

Expand Down Expand Up @@ -93,10 +100,13 @@ func TestPostgres_FindMany(t *testing.T) {
func TestPostgres_InsertOne(t *testing.T) {
db, closer := initializeDB(t)
defer closer()
db, err := db.BeginTx()
assert.Nil(t, err)

db = db.Table("test_user")
t.Run("insert data", func(t *testing.T) {
suffix := xid.New().String()
//suffix := "hello"
user := TestUser{
Name: "test-" + suffix,
FullName: "Test Name",
Expand All @@ -105,6 +115,13 @@ func TestPostgres_InsertOne(t *testing.T) {
id, err := db.InsertOne(&user)
assert.Nil(t, err)
assert.NotEqual(t, 0, id)
if err != nil {
err = db.Rollback()
assert.Nil(t, err)
}

err = db.Commit()
assert.Nil(t, err)
})
}

Expand Down
34 changes: 26 additions & 8 deletions sql/sqlite/lib/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,23 @@ func (stmt Statement) GenerateReadQuery() string {
return query
}

func (stmt Statement) ExecuteReadQuery(ctx context.Context, conn *sql.Conn, query string, doc any) error {
func (stmt Statement) ExecuteReadQuery(ctx context.Context, conn *sql.Conn, tx *sql.Tx, query string, doc any) error {
//defer stmt.cleanup()

if stmt.showSQL {
log.Printf("Read Query: query: %v, args: %v\n", query, stmt.args)
}

rows, err := conn.QueryContext(ctx, query, stmt.args...)
var (
err error
rows *sql.Rows
)

if tx != nil {
rows, err = tx.QueryContext(ctx, query, stmt.args...)
} else {
rows, err = conn.QueryContext(ctx, query, stmt.args...)
}
if err != nil {
return err
}
Expand Down Expand Up @@ -183,25 +192,34 @@ func (stmt Statement) GenerateInsertQuery(doc any) string {
return query
}

func (stmt Statement) ExecuteInsertQuery(ctx context.Context, conn *sql.Conn, query string) (any, error) {
func (stmt Statement) ExecuteInsertQuery(ctx context.Context, conn *sql.Conn, tx *sql.Tx, query string) (any, error) {
query += " RETURNING id;"
if stmt.showSQL {
log.Printf("Insert Query: query: %v, args: %v\n", query, stmt.args)
}

var id any
err := conn.QueryRowContext(ctx, query, stmt.args...).Scan(&id)
var (
id any
err error
)
if tx != nil {
err = tx.QueryRowContext(ctx, query, stmt.args...).Scan(&id)
} else {
err = conn.QueryRowContext(ctx, query, stmt.args...).Scan(&id)
}
return id, err
}

func (stmt Statement) ExecuteWriteQuery(ctx context.Context, conn *sql.Conn, query string) (sql.Result, error) {
func (stmt Statement) ExecuteWriteQuery(ctx context.Context, conn *sql.Conn, tx *sql.Tx, query string) (sql.Result, error) {
if stmt.showSQL {
log.Printf("Write Query: query: %v, args: %v\n", query, stmt.args)
}

result, err := conn.ExecContext(ctx, query, stmt.args...)
if tx != nil {
return tx.ExecContext(ctx, query, stmt.args...)
}

return result, err
return conn.ExecContext(ctx, query, stmt.args...)
}

func (stmt Statement) generateMustColMap() map[string]bool {
Expand Down
Loading

0 comments on commit 702c1b2

Please sign in to comment.