Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce transaction for postgres, sqlite database #6

Merged
merged 2 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sql/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
)

type Database interface {
BeginTx() (Database, error)
Commit() error
Rollback() error

Table(name string) Database

ID(id any) Database
Expand Down
35 changes: 26 additions & 9 deletions sql/postgres/lib/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,23 @@ func (stmt Statement) GenerateReadQuery() string {
return query
}

func (stmt Statement) ExecuteReadQuery(ctx context.Context, conn *sql.Conn, query string, doc any) error {
func (stmt Statement) ExecuteReadQuery(ctx context.Context, conn *sql.Conn, tx *sql.Tx, query string, doc any) error {
//defer stmt.cleanup()

if stmt.showSQL {
log.Printf("Read Query: query: %v, args: %v\n", query, stmt.args)
}

rows, err := conn.QueryContext(ctx, query, stmt.args...)
var (
err error
rows *sql.Rows
)

if tx != nil {
rows, err = tx.QueryContext(ctx, query, stmt.args...)
} else {
rows, err = conn.QueryContext(ctx, query, stmt.args...)
}
if err != nil {
return err
}
Expand Down Expand Up @@ -183,25 +192,33 @@ func (stmt Statement) GenerateInsertQuery(doc any) string {
return query
}

func (stmt Statement) ExecuteInsertQuery(ctx context.Context, conn *sql.Conn, query string) (any, error) {
func (stmt Statement) ExecuteInsertQuery(ctx context.Context, conn *sql.Conn, tx *sql.Tx, query string) (any, error) {
query += " RETURNING id;"
if stmt.showSQL {
log.Printf("Insert Query: query: %v, args: %v\n", query, stmt.args)
}

var id any
err := conn.QueryRowContext(ctx, query, stmt.args...).Scan(&id)
var (
id any
err error
)
if tx != nil {
err = tx.QueryRowContext(ctx, query, stmt.args...).Scan(&id)
} else {
err = conn.QueryRowContext(ctx, query, stmt.args...).Scan(&id)
}
return id, err
}

func (stmt Statement) ExecuteWriteQuery(ctx context.Context, conn *sql.Conn, query string) (sql.Result, error) {
func (stmt Statement) ExecuteWriteQuery(ctx context.Context, conn *sql.Conn, tx *sql.Tx, query string) (sql.Result, error) {
if stmt.showSQL {
log.Printf("Write Query: query: %v, args: %v\n", query, stmt.args)
}

result, err := conn.ExecContext(ctx, query, stmt.args...)

return result, err
if tx != nil {
return tx.ExecContext(ctx, query, stmt.args...)
}
return conn.ExecContext(ctx, query, stmt.args...)
}

func (stmt Statement) generateMustColMap() map[string]bool {
Expand Down
50 changes: 42 additions & 8 deletions sql/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package postgres
import (
"context"
"database/sql"
"errors"

isql "github.com/masudur-rahman/database/sql"
"github.com/masudur-rahman/database/sql/postgres/lib"
Expand All @@ -11,13 +12,46 @@ import (
type Postgres struct {
ctx context.Context
conn *sql.Conn
tx *sql.Tx
statement lib.Statement
}

func NewPostgres(ctx context.Context, conn *sql.Conn) Postgres {
return Postgres{ctx: ctx, conn: conn}
}

var _ isql.Database = Postgres{}

func (pg Postgres) BeginTx() (isql.Database, error) {
if pg.tx != nil {
return nil, errors.New("session already in progress")
}
tx, err := pg.conn.BeginTx(pg.ctx, nil)
if err != nil {
return nil, err
}
pg.tx = tx
return pg, nil
}

func (pg Postgres) Commit() error {
if pg.tx == nil {
return errors.New("no transaction in progress")
}
err := pg.tx.Commit()
pg.tx = nil
return err
}

func (pg Postgres) Rollback() error {
if pg.tx == nil {
return errors.New("no transaction in progress")
}
err := pg.tx.Rollback()
pg.tx = nil
return err
}

func (pg Postgres) Table(name string) isql.Database {
pg.statement = pg.statement.Table(name)
return pg
Expand Down Expand Up @@ -66,7 +100,7 @@ func (pg Postgres) FindOne(document any, filter ...any) (bool, error) {
}

query := pg.statement.GenerateReadQuery()
err := pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, query, document)
err := pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, pg.tx, query, document)
if err == nil {
return true, nil
}
Expand All @@ -81,19 +115,19 @@ func (pg Postgres) FindMany(documents any, filter ...any) error {
pg.statement = pg.statement.GenerateWhereClause(filter...)

query := pg.statement.GenerateReadQuery()
return pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, query, documents)
return pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, pg.tx, query, documents)
}

func (pg Postgres) InsertOne(document any) (id any, err error) {
query := pg.statement.GenerateInsertQuery(document)
return pg.statement.ExecuteInsertQuery(pg.ctx, pg.conn, query)
return pg.statement.ExecuteInsertQuery(pg.ctx, pg.conn, pg.tx, query)
}

func (pg Postgres) InsertMany(documents []any) ([]any, error) {
var ids []any
for _, doc := range documents {
query := pg.statement.GenerateInsertQuery(doc)
id, err := pg.statement.ExecuteInsertQuery(pg.ctx, pg.conn, query)
id, err := pg.statement.ExecuteInsertQuery(pg.ctx, pg.conn, pg.tx, query)
if err != nil {
return nil, err
}
Expand All @@ -110,7 +144,7 @@ func (pg Postgres) UpdateOne(document any) error {
}

query := pg.statement.GenerateUpdateQuery(document)
_, err := pg.statement.ExecuteWriteQuery(pg.ctx, pg.conn, query)
_, err := pg.statement.ExecuteWriteQuery(pg.ctx, pg.conn, pg.tx, query)
return err
}

Expand All @@ -121,7 +155,7 @@ func (pg Postgres) DeleteOne(filter ...any) error {
}

query := pg.statement.GenerateDeleteQuery()
_, err := pg.statement.ExecuteWriteQuery(pg.ctx, pg.conn, query)
_, err := pg.statement.ExecuteWriteQuery(pg.ctx, pg.conn, pg.tx, query)
return err
}

Expand All @@ -144,8 +178,8 @@ func (pg Postgres) Sync(tables ...any) error {
return nil
}

func (p Postgres) Close() error {
return p.conn.Close()
func (pg Postgres) Close() error {
return pg.conn.Close()
}

func (pg Postgres) cleanup() {
Expand Down
17 changes: 17 additions & 0 deletions sql/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ func TestPostgres_FindOne(t *testing.T) {
db, closer := initializeDB(t)
defer closer()

db, err := db.BeginTx()
assert.Nil(t, err)
defer func() {
err = db.Commit()
assert.Nil(t, err)
}()

user := TestUser{}
db = db.Table("test_user")

Expand Down Expand Up @@ -93,10 +100,13 @@ func TestPostgres_FindMany(t *testing.T) {
func TestPostgres_InsertOne(t *testing.T) {
db, closer := initializeDB(t)
defer closer()
db, err := db.BeginTx()
assert.Nil(t, err)

db = db.Table("test_user")
t.Run("insert data", func(t *testing.T) {
suffix := xid.New().String()
//suffix := "hello"
user := TestUser{
Name: "test-" + suffix,
FullName: "Test Name",
Expand All @@ -105,6 +115,13 @@ func TestPostgres_InsertOne(t *testing.T) {
id, err := db.InsertOne(&user)
assert.Nil(t, err)
assert.NotEqual(t, 0, id)
if err != nil {
err = db.Rollback()
assert.Nil(t, err)
}

err = db.Commit()
assert.Nil(t, err)
})
}

Expand Down
34 changes: 26 additions & 8 deletions sql/sqlite/lib/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,23 @@ func (stmt Statement) GenerateReadQuery() string {
return query
}

func (stmt Statement) ExecuteReadQuery(ctx context.Context, conn *sql.Conn, query string, doc any) error {
func (stmt Statement) ExecuteReadQuery(ctx context.Context, conn *sql.Conn, tx *sql.Tx, query string, doc any) error {
//defer stmt.cleanup()

if stmt.showSQL {
log.Printf("Read Query: query: %v, args: %v\n", query, stmt.args)
}

rows, err := conn.QueryContext(ctx, query, stmt.args...)
var (
err error
rows *sql.Rows
)

if tx != nil {
rows, err = tx.QueryContext(ctx, query, stmt.args...)
} else {
rows, err = conn.QueryContext(ctx, query, stmt.args...)
}
if err != nil {
return err
}
Expand Down Expand Up @@ -183,25 +192,34 @@ func (stmt Statement) GenerateInsertQuery(doc any) string {
return query
}

func (stmt Statement) ExecuteInsertQuery(ctx context.Context, conn *sql.Conn, query string) (any, error) {
func (stmt Statement) ExecuteInsertQuery(ctx context.Context, conn *sql.Conn, tx *sql.Tx, query string) (any, error) {
query += " RETURNING id;"
if stmt.showSQL {
log.Printf("Insert Query: query: %v, args: %v\n", query, stmt.args)
}

var id any
err := conn.QueryRowContext(ctx, query, stmt.args...).Scan(&id)
var (
id any
err error
)
if tx != nil {
err = tx.QueryRowContext(ctx, query, stmt.args...).Scan(&id)
} else {
err = conn.QueryRowContext(ctx, query, stmt.args...).Scan(&id)
}
return id, err
}

func (stmt Statement) ExecuteWriteQuery(ctx context.Context, conn *sql.Conn, query string) (sql.Result, error) {
func (stmt Statement) ExecuteWriteQuery(ctx context.Context, conn *sql.Conn, tx *sql.Tx, query string) (sql.Result, error) {
if stmt.showSQL {
log.Printf("Write Query: query: %v, args: %v\n", query, stmt.args)
}

result, err := conn.ExecContext(ctx, query, stmt.args...)
if tx != nil {
return tx.ExecContext(ctx, query, stmt.args...)
}

return result, err
return conn.ExecContext(ctx, query, stmt.args...)
}

func (stmt Statement) generateMustColMap() map[string]bool {
Expand Down
Loading
Loading