From 3d403cfb206be60b88ea2e0dc157ec3df5148639 Mon Sep 17 00:00:00 2001 From: Masudur Rahman Date: Thu, 4 Apr 2024 12:23:17 +0600 Subject: [PATCH] Make Table method not required (#8) Signed-off-by: Masudur Rahman --- sql/postgres/lib/statement.go | 19 ++++++++++++++++++- sql/postgres/lib/table_sync.go | 10 +++++++--- sql/postgres/postgres.go | 4 ++-- sql/postgres/postgres_test.go | 6 +++--- sql/sqlite/lib/statement.go | 20 +++++++++++++++++++- sql/sqlite/lib/table_sync.go | 8 ++++++-- sql/sqlite/sqlite.go | 4 ++-- sql/sqlite/sqlite_test.go | 2 +- sql/sqlite/test.db | Bin 12288 -> 12288 bytes 9 files changed, 58 insertions(+), 15 deletions(-) diff --git a/sql/postgres/lib/statement.go b/sql/postgres/lib/statement.go index c09cea3..db2546d 100644 --- a/sql/postgres/lib/statement.go +++ b/sql/postgres/lib/statement.go @@ -101,7 +101,7 @@ func (stmt Statement) ShowSQL(showSQL bool) Statement { return stmt } -func (stmt Statement) GenerateReadQuery() string { +func (stmt Statement) GenerateReadQuery(doc any) string { var cols string if stmt.allCols || len(stmt.columns) == 0 { cols = "*" @@ -109,6 +109,15 @@ func (stmt Statement) GenerateReadQuery() string { cols = strings.Join(stmt.columns, ", ") } + if stmt.table == "" { + val := reflect.ValueOf(doc) + if val.Kind() == reflect.Slice { + doc = val.Index(0).Interface() + } + + stmt.table = GenerateTableName(doc) + } + query := fmt.Sprintf("SELECT %s FROM \"%s\"", cols, stmt.table) if stmt.where != "" { @@ -185,6 +194,10 @@ func (stmt Statement) GenerateInsertQuery(doc any) string { values = append(values, value) } + if stmt.table == "" { + stmt.table = GenerateTableName(doc) + } + colClause := strings.Join(cols, ", ") valClause := strings.Join(values, ", ") query := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)", stmt.table, colClause, valClause) @@ -249,6 +262,10 @@ func (stmt Statement) GenerateUpdateQuery(doc any) string { setValues = append(setValues, setValue) } + if stmt.table == "" { + stmt.table = GenerateTableName(doc) + } + setClause := strings.Join(setValues, ", ") query := fmt.Sprintf("UPDATE \"%s\" SET %s WHERE %s", stmt.table, setClause, stmt.where) return query diff --git a/sql/postgres/lib/table_sync.go b/sql/postgres/lib/table_sync.go index 6dc62a6..98901b8 100644 --- a/sql/postgres/lib/table_sync.go +++ b/sql/postgres/lib/table_sync.go @@ -17,12 +17,16 @@ type fieldInfo struct { IsComposite bool } -func getTableName(table interface{}) string { +func GenerateTableName(table interface{}) string { tableType := reflect.TypeOf(table) tableValue := reflect.ValueOf(table) if tableType.Kind() == reflect.Ptr { tableType = tableType.Elem() - tableValue = tableValue.Elem() + tableValue = reflect.New(tableType) + } + if tableType.Kind() == reflect.Slice { + tableType = tableType.Elem() + tableValue = reflect.New(tableType) } tableName := tableType.Name() tableName = strcase.ToSnake(tableName) @@ -357,7 +361,7 @@ func contains(slice []string, val string) bool { } func SyncTable(ctx context.Context, conn *sql.Conn, table any) error { - tableName := getTableName(table) + tableName := GenerateTableName(table) fields, err := getTableInfo(table) if err != nil { return err diff --git a/sql/postgres/postgres.go b/sql/postgres/postgres.go index 568c9df..06cee33 100644 --- a/sql/postgres/postgres.go +++ b/sql/postgres/postgres.go @@ -99,7 +99,7 @@ func (pg Postgres) FindOne(document any, filter ...any) (bool, error) { return false, err } - query := pg.statement.GenerateReadQuery() + query := pg.statement.GenerateReadQuery(document) err := pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, pg.tx, query, document) if err == nil { return true, nil @@ -114,7 +114,7 @@ func (pg Postgres) FindOne(document any, filter ...any) (bool, error) { func (pg Postgres) FindMany(documents any, filter ...any) error { pg.statement = pg.statement.GenerateWhereClause(filter...) - query := pg.statement.GenerateReadQuery() + query := pg.statement.GenerateReadQuery(documents) return pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, pg.tx, query, documents) } diff --git a/sql/postgres/postgres_test.go b/sql/postgres/postgres_test.go index 1f3fa1a..fb80923 100644 --- a/sql/postgres/postgres_test.go +++ b/sql/postgres/postgres_test.go @@ -59,7 +59,7 @@ func TestPostgres_FindOne(t *testing.T) { }() user := TestUser{} - db = db.Table("test_user") + //db = db.Table("test_user") t.Run("find user by id", func(t *testing.T) { has, err := db.ID(1).FindOne(&user) @@ -70,7 +70,7 @@ func TestPostgres_FindOne(t *testing.T) { t.Run("find user by filter", func(t *testing.T) { has, err := db.Where("email=?", "test@test.test").FindOne(&user, TestUser{Name: "test"}) assert.Nil(t, err) - assert.True(t, has) + assert.False(t, has) }) } @@ -79,7 +79,7 @@ func TestPostgres_FindMany(t *testing.T) { defer closer() var users []TestUser - db = db.Table("test_user") + //db = db.Table("test_user") t.Run("find all", func(t *testing.T) { err := db.FindMany(&users) diff --git a/sql/sqlite/lib/statement.go b/sql/sqlite/lib/statement.go index d2b6c1e..129c3e0 100644 --- a/sql/sqlite/lib/statement.go +++ b/sql/sqlite/lib/statement.go @@ -101,7 +101,7 @@ func (stmt Statement) ShowSQL(showSQL bool) Statement { return stmt } -func (stmt Statement) GenerateReadQuery() string { +func (stmt Statement) GenerateReadQuery(doc any) string { var cols string if stmt.allCols || len(stmt.columns) == 0 { cols = "*" @@ -109,6 +109,16 @@ func (stmt Statement) GenerateReadQuery() string { cols = strings.Join(stmt.columns, ", ") } + if stmt.table == "" { + //val := reflect.TypeOf(doc).Elem() + //if val.Kind() == reflect.Slice { + // val.Name() + // //doc = val.Index(0).Interface() + //} + + stmt.table = GenerateTableName(doc) + } + query := fmt.Sprintf("SELECT %s FROM \"%s\"", cols, stmt.table) if stmt.where != "" { @@ -185,6 +195,10 @@ func (stmt Statement) GenerateInsertQuery(doc any) string { values = append(values, value) } + if stmt.table == "" { + stmt.table = GenerateTableName(doc) + } + colClause := strings.Join(cols, ", ") valClause := strings.Join(values, ", ") query := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)", stmt.table, colClause, valClause) @@ -250,6 +264,10 @@ func (stmt Statement) GenerateUpdateQuery(doc any) string { setValues = append(setValues, setValue) } + if stmt.table == "" { + stmt.table = GenerateTableName(doc) + } + setClause := strings.Join(setValues, ", ") query := fmt.Sprintf("UPDATE \"%s\" SET %s WHERE %s", stmt.table, setClause, stmt.where) return query diff --git a/sql/sqlite/lib/table_sync.go b/sql/sqlite/lib/table_sync.go index 3a2b887..f19cccf 100644 --- a/sql/sqlite/lib/table_sync.go +++ b/sql/sqlite/lib/table_sync.go @@ -17,13 +17,17 @@ type fieldInfo struct { IsComposite bool } -func getTableName(table interface{}) string { +func GenerateTableName(table interface{}) string { tableType := reflect.TypeOf(table) tableValue := reflect.ValueOf(table) if tableType.Kind() == reflect.Ptr { tableType = tableType.Elem() tableValue = tableValue.Elem() } + if tableType.Kind() == reflect.Slice { + tableType = tableType.Elem() + tableValue = reflect.New(tableType) + } tableName := tableType.Name() tableName = strcase.ToSnake(tableName) if method := tableValue.MethodByName("TableName"); method.IsValid() { @@ -363,7 +367,7 @@ func generateAddColumnQuery(tableName string, missingColumns []string) string { } func SyncTable(ctx context.Context, conn *sql.Conn, table interface{}) error { - tableName := getTableName(table) + tableName := GenerateTableName(table) fields, err := getTableInfo(table) if err != nil { return err diff --git a/sql/sqlite/sqlite.go b/sql/sqlite/sqlite.go index b93f60f..b98eee0 100644 --- a/sql/sqlite/sqlite.go +++ b/sql/sqlite/sqlite.go @@ -101,7 +101,7 @@ func (sq SQLite) FindOne(document any, filter ...any) (bool, error) { return false, err } - query := sq.statement.GenerateReadQuery() + query := sq.statement.GenerateReadQuery(document) err := sq.statement.ExecuteReadQuery(sq.ctx, sq.conn, sq.tx, query, document) if err == nil { return true, nil @@ -116,7 +116,7 @@ func (sq SQLite) FindOne(document any, filter ...any) (bool, error) { func (sq SQLite) FindMany(documents any, filter ...any) error { sq.statement = sq.statement.GenerateWhereClause(filter...) - query := sq.statement.GenerateReadQuery() + query := sq.statement.GenerateReadQuery(documents) return sq.statement.ExecuteReadQuery(sq.ctx, sq.conn, sq.tx, query, documents) } diff --git a/sql/sqlite/sqlite_test.go b/sql/sqlite/sqlite_test.go index ab23da6..6119bc8 100644 --- a/sql/sqlite/sqlite_test.go +++ b/sql/sqlite/sqlite_test.go @@ -67,7 +67,7 @@ func TestPostgres_FindMany(t *testing.T) { defer closer() var users []User - db = db.Table("user") + //db = db.Table("user") t.Run("find all", func(t *testing.T) { err := db.FindMany(&users) diff --git a/sql/sqlite/test.db b/sql/sqlite/test.db index 6ae4b3bab1a566434fda471e830e588ed5fe4548..bc9e19d09b58b6b56cab6e253e305e4bccc9df9c 100644 GIT binary patch delta 202 zcmZojXh@hK&B!}Z#+i|KW5PmyLB2B#%zX6>{G0gG`07D$V__~|pf@WUgMBcAd`W6? ziEeVfxk*w=Mv7T_SxItoMqY_=z5z}N2at>&hzJ2<1;51HR0alyS(`7&88Qm;Co?ef nTQl%m^WWm%41${l6*T#E)R}`Bi7=2AXi&i9w{pspyY$Ndq|HBO delta 81 zcmZojXh@hK&B!%T#+i|8W5Pmyeg*~x7QQnK{G0gG_$B$yY%JWv$KuVx#<2N