Skip to content

Commit

Permalink
Make Table method not required (#8)
Browse files Browse the repository at this point in the history
Signed-off-by: Masudur Rahman <[email protected]>
  • Loading branch information
masudur-rahman committed Apr 4, 2024
1 parent f04ecee commit 3d403cf
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 15 deletions.
19 changes: 18 additions & 1 deletion sql/postgres/lib/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,23 @@ func (stmt Statement) ShowSQL(showSQL bool) Statement {
return stmt
}

func (stmt Statement) GenerateReadQuery() string {
func (stmt Statement) GenerateReadQuery(doc any) string {
var cols string
if stmt.allCols || len(stmt.columns) == 0 {
cols = "*"
} else {
cols = strings.Join(stmt.columns, ", ")
}

if stmt.table == "" {
val := reflect.ValueOf(doc)
if val.Kind() == reflect.Slice {
doc = val.Index(0).Interface()
}

stmt.table = GenerateTableName(doc)
}

query := fmt.Sprintf("SELECT %s FROM \"%s\"", cols, stmt.table)

if stmt.where != "" {
Expand Down Expand Up @@ -185,6 +194,10 @@ func (stmt Statement) GenerateInsertQuery(doc any) string {
values = append(values, value)
}

if stmt.table == "" {
stmt.table = GenerateTableName(doc)
}

colClause := strings.Join(cols, ", ")
valClause := strings.Join(values, ", ")
query := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)", stmt.table, colClause, valClause)
Expand Down Expand Up @@ -249,6 +262,10 @@ func (stmt Statement) GenerateUpdateQuery(doc any) string {
setValues = append(setValues, setValue)
}

if stmt.table == "" {
stmt.table = GenerateTableName(doc)
}

setClause := strings.Join(setValues, ", ")
query := fmt.Sprintf("UPDATE \"%s\" SET %s WHERE %s", stmt.table, setClause, stmt.where)
return query
Expand Down
10 changes: 7 additions & 3 deletions sql/postgres/lib/table_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@ type fieldInfo struct {
IsComposite bool
}

func getTableName(table interface{}) string {
func GenerateTableName(table interface{}) string {
tableType := reflect.TypeOf(table)
tableValue := reflect.ValueOf(table)
if tableType.Kind() == reflect.Ptr {
tableType = tableType.Elem()
tableValue = tableValue.Elem()
tableValue = reflect.New(tableType)
}
if tableType.Kind() == reflect.Slice {
tableType = tableType.Elem()
tableValue = reflect.New(tableType)
}
tableName := tableType.Name()
tableName = strcase.ToSnake(tableName)
Expand Down Expand Up @@ -357,7 +361,7 @@ func contains(slice []string, val string) bool {
}

func SyncTable(ctx context.Context, conn *sql.Conn, table any) error {
tableName := getTableName(table)
tableName := GenerateTableName(table)
fields, err := getTableInfo(table)
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions sql/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (pg Postgres) FindOne(document any, filter ...any) (bool, error) {
return false, err
}

query := pg.statement.GenerateReadQuery()
query := pg.statement.GenerateReadQuery(document)
err := pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, pg.tx, query, document)
if err == nil {
return true, nil
Expand All @@ -114,7 +114,7 @@ func (pg Postgres) FindOne(document any, filter ...any) (bool, error) {
func (pg Postgres) FindMany(documents any, filter ...any) error {
pg.statement = pg.statement.GenerateWhereClause(filter...)

query := pg.statement.GenerateReadQuery()
query := pg.statement.GenerateReadQuery(documents)
return pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, pg.tx, query, documents)
}

Expand Down
6 changes: 3 additions & 3 deletions sql/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestPostgres_FindOne(t *testing.T) {
}()

user := TestUser{}
db = db.Table("test_user")
//db = db.Table("test_user")

t.Run("find user by id", func(t *testing.T) {
has, err := db.ID(1).FindOne(&user)
Expand All @@ -70,7 +70,7 @@ func TestPostgres_FindOne(t *testing.T) {
t.Run("find user by filter", func(t *testing.T) {
has, err := db.Where("email=?", "[email protected]").FindOne(&user, TestUser{Name: "test"})
assert.Nil(t, err)
assert.True(t, has)
assert.False(t, has)
})
}

Expand All @@ -79,7 +79,7 @@ func TestPostgres_FindMany(t *testing.T) {
defer closer()

var users []TestUser
db = db.Table("test_user")
//db = db.Table("test_user")

t.Run("find all", func(t *testing.T) {
err := db.FindMany(&users)
Expand Down
20 changes: 19 additions & 1 deletion sql/sqlite/lib/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,24 @@ func (stmt Statement) ShowSQL(showSQL bool) Statement {
return stmt
}

func (stmt Statement) GenerateReadQuery() string {
func (stmt Statement) GenerateReadQuery(doc any) string {
var cols string
if stmt.allCols || len(stmt.columns) == 0 {
cols = "*"
} else {
cols = strings.Join(stmt.columns, ", ")
}

if stmt.table == "" {
//val := reflect.TypeOf(doc).Elem()
//if val.Kind() == reflect.Slice {
// val.Name()
// //doc = val.Index(0).Interface()
//}

stmt.table = GenerateTableName(doc)
}

query := fmt.Sprintf("SELECT %s FROM \"%s\"", cols, stmt.table)

if stmt.where != "" {
Expand Down Expand Up @@ -185,6 +195,10 @@ func (stmt Statement) GenerateInsertQuery(doc any) string {
values = append(values, value)
}

if stmt.table == "" {
stmt.table = GenerateTableName(doc)
}

colClause := strings.Join(cols, ", ")
valClause := strings.Join(values, ", ")
query := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)", stmt.table, colClause, valClause)
Expand Down Expand Up @@ -250,6 +264,10 @@ func (stmt Statement) GenerateUpdateQuery(doc any) string {
setValues = append(setValues, setValue)
}

if stmt.table == "" {
stmt.table = GenerateTableName(doc)
}

setClause := strings.Join(setValues, ", ")
query := fmt.Sprintf("UPDATE \"%s\" SET %s WHERE %s", stmt.table, setClause, stmt.where)
return query
Expand Down
8 changes: 6 additions & 2 deletions sql/sqlite/lib/table_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@ type fieldInfo struct {
IsComposite bool
}

func getTableName(table interface{}) string {
func GenerateTableName(table interface{}) string {
tableType := reflect.TypeOf(table)
tableValue := reflect.ValueOf(table)
if tableType.Kind() == reflect.Ptr {
tableType = tableType.Elem()
tableValue = tableValue.Elem()
}
if tableType.Kind() == reflect.Slice {
tableType = tableType.Elem()
tableValue = reflect.New(tableType)
}
tableName := tableType.Name()
tableName = strcase.ToSnake(tableName)
if method := tableValue.MethodByName("TableName"); method.IsValid() {
Expand Down Expand Up @@ -363,7 +367,7 @@ func generateAddColumnQuery(tableName string, missingColumns []string) string {
}

func SyncTable(ctx context.Context, conn *sql.Conn, table interface{}) error {
tableName := getTableName(table)
tableName := GenerateTableName(table)
fields, err := getTableInfo(table)
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions sql/sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (sq SQLite) FindOne(document any, filter ...any) (bool, error) {
return false, err
}

query := sq.statement.GenerateReadQuery()
query := sq.statement.GenerateReadQuery(document)
err := sq.statement.ExecuteReadQuery(sq.ctx, sq.conn, sq.tx, query, document)
if err == nil {
return true, nil
Expand All @@ -116,7 +116,7 @@ func (sq SQLite) FindOne(document any, filter ...any) (bool, error) {
func (sq SQLite) FindMany(documents any, filter ...any) error {
sq.statement = sq.statement.GenerateWhereClause(filter...)

query := sq.statement.GenerateReadQuery()
query := sq.statement.GenerateReadQuery(documents)
return sq.statement.ExecuteReadQuery(sq.ctx, sq.conn, sq.tx, query, documents)
}

Expand Down
2 changes: 1 addition & 1 deletion sql/sqlite/sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestPostgres_FindMany(t *testing.T) {
defer closer()

var users []User
db = db.Table("user")
//db = db.Table("user")

t.Run("find all", func(t *testing.T) {
err := db.FindMany(&users)
Expand Down
Binary file modified sql/sqlite/test.db
Binary file not shown.

0 comments on commit 3d403cf

Please sign in to comment.