From 2c5a00d91619af9451f5e75a78b9b6532815d3ba Mon Sep 17 00:00:00 2001
From: aceforeverd <teapot@aceforeverd.com>
Date: Sat, 27 Apr 2024 00:43:16 +0800
Subject: [PATCH] feat: timestamp & date (#8)

- fix timestamp or date type as go query parameters.
- basic facility to support SQL Null as input or output.
- more tests
---
 .github/workflows/go.yml |   4 +-
 conn.go                  |  56 +++++++---
 conn_test.go             |  50 +++++++--
 encode.go                |  15 +++
 go.mod                   |   2 +-
 go_sdk_test.go           | 233 ++++++++++++++++++++++++++++-----------
 types.go                 |  64 ++++++-----
 7 files changed, 299 insertions(+), 125 deletions(-)
 create mode 100644 encode.go

diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index bbb8cca..daf06c5 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -15,7 +15,7 @@ jobs:
 
       - uses: actions/setup-go@v5
         with:
-          go-version: '^1.18'
+          go-version: '^1.22'
 
       - name: OpenMLDB cluster
         run: |
@@ -33,7 +33,7 @@ jobs:
           docker compose -f docker-compose.yml exec openmldb-ns1 /opt/openmldb/bin/openmldb --zk_cluster=openmldb-zk:2181 --zk_root_path=/openmldb --role=sql_client --cmd 'SET GLOBAL execute_mode = "online"'
 
       - name: go test
-        run: go test ./... -race -covermode=atomic -coverprofile=coverage.out
+        run: go test ./... -race -covermode=atomic -coverprofile=coverage.out -v
 
       - name: Coverage
         uses: codecov/codecov-action@v4
diff --git a/conn.go b/conn.go
index a2e75c9..7beccb0 100644
--- a/conn.go
+++ b/conn.go
@@ -3,6 +3,7 @@ package openmldb
 import (
 	"bytes"
 	"context"
+	"database/sql"
 	"database/sql/driver"
 	"encoding/json"
 	"errors"
@@ -71,7 +72,7 @@ type queryResp struct {
 }
 
 type respData struct {
-	Schema []string             `json:"schema"`
+	Schema []string         `json:"schema"`
 	Data   [][]driver.Value `json:"data"`
 }
 
@@ -127,36 +128,48 @@ type queryReq struct {
 }
 
 type queryInput struct {
-	Schema []string           `json:"schema"`
+	Schema []string       `json:"schema"`
 	Data   []driver.Value `json:"data"`
 }
 
-func marshalQueryRequest(mode, sql string, input ...driver.Value) ([]byte, error) {
+func marshalQueryRequest(mode string, sqlStr string, input ...driver.Value) ([]byte, error) {
 	req := queryReq{
 		Mode: mode,
-		SQL:  sql,
+		SQL:  sqlStr,
 	}
 
+	// TODO(someone): Type infer from input slice does not work always. Consider those cases:
+	// 1. a int type can be a int32 or int64, depends on value size.
+	// 2. we're not covering more input types like uint.
+	// 3. For a int16 or int32 input from DB.Query(...), it always convert to int64 because driver.Value
+	//    only expect int64 from primitive types.
+	//
+	// A better approach is to ask the schema types from api server, which in turn ask types info to SQL compiler.
+
 	if len(input) > 0 {
 		schema := make([]string, len(input))
+		// TODO(someone): support value as nil, at current time it is not possible to infer SQL type from a nil
 		for i, v := range input {
-			switch v.(type) {
-			case bool:
+			switch vv := v.(type) {
+			case bool, Null[bool]:
 				schema[i] = "bool"
-			case int16:
+			case int16, Null[int16]:
 				schema[i] = "int16"
-			case int32:
+			case int32, Null[int32]:
 				schema[i] = "int32"
-			case int64:
+			case int64, Null[int64]:
 				schema[i] = "int64"
-			case float32:
+			case float32, Null[float32]:
 				schema[i] = "float"
-			case float64:
+			case float64, Null[float64]:
 				schema[i] = "double"
-			case string:
+			case string, Null[string]:
 				schema[i] = "string"
 			case time.Time:
 				schema[i] = "timestamp"
+				input[i] = Null[time.Time]{Null: sql.Null[time.Time]{V: vv, Valid: true}}
+			case Null[time.Time]:
+				schema[i] = "timestamp"
 			case NullDate:
 				schema[i] = "date"
 			default:
@@ -179,8 +192,14 @@ func unmarshalQueryResponse(respBody io.Reader) (*queryResp, error) {
 	}
 
 	if r.Data != nil {
+		// queryResp.Data may nil for DDL
 		for _, row := range r.Data.Data {
 			for i, col := range row {
+				if col == nil {
+					row[i] = nil
+					continue
+				}
+
 				switch strings.ToLower(r.Data.Schema[i]) {
 				case "bool":
 					row[i] = col.(bool)
@@ -196,14 +215,17 @@ func unmarshalQueryResponse(respBody io.Reader) (*queryResp, error) {
 					row[i] = float64(col.(float64))
 				case "string":
 					row[i] = col.(string)
+				// date and timestamp values saved internally as time.Time
 				case "timestamp":
 					// timestamp value returned as int64 millisecond unix epoch time
 					row[i] = time.UnixMilli(int64(col.(float64)))
 				case "date":
-					// date values returned as "YYYY-mm-dd" formated string
-					var nullDate NullDate
-					nullDate.Scan(col.(string))
-					row[i] = nullDate
+					t, err := parseDateStr(col.(string))
+					if err != nil {
+						row[i] = nil
+					}
+
+					row[i] = t
 				default:
 					return nil, fmt.Errorf("unknown type %s at index %d", r.Data.Schema[i], i)
 				}
@@ -244,7 +266,7 @@ func (c *conn) execute(ctx context.Context, sql string, parameters ...driver.Val
 	if r, err := unmarshalQueryResponse(resp.Body); err != nil {
 		return nil, err
 	} else if r.Code != 0 {
-		return nil, fmt.Errorf("conn error: %s", r.Msg)
+		return nil, fmt.Errorf("execute error: %s", r.Msg)
 	} else if r.Data != nil {
 		return &respDataRows{*r.Data, 0}, nil
 	}
diff --git a/conn_test.go b/conn_test.go
index 560c7a5..2c1c2be 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -1,9 +1,11 @@
 package openmldb
 
 import (
-	interfaces "database/sql/driver"
+	"database/sql"
+	"database/sql/driver"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 )
@@ -12,28 +14,36 @@ func TestParseReqToJson(t *testing.T) {
 	for _, tc := range []struct {
 		mode   string
 		sql    string
-		input  []interfaces.Value
+		input  []driver.Value
 		expect string
 	}{
 		{
-			"offsync",
+			"offline",
 			"SELECT 1;",
 			nil,
 			`{
-				"mode": "offsync",
+				"mode": "offline",
 				"sql": "SELECT 1;"
 			}`,
 		},
 		{
-			"offsync",
+			"online",
 			"SELECT c1, c2 FROM demo WHERE c1 = ? AND c2 = ?;",
-			[]interfaces.Value{int32(1), "bb"},
+			[]driver.Value{
+				int16(2), // int16
+				int32(1), // int32
+				"bb",     // string
+				Null[string]{Null: sql.Null[string]{V: "foo", Valid: true}}, // string
+				time.UnixMilli(8000), // timestamp
+				Null[time.Time]{Null: sql.Null[time.Time]{V: time.UnixMilli(4000), Valid: true}},                              // timestamp
+				Null[time.Time]{Null: sql.Null[time.Time]{V: time.UnixMilli(4000), Valid: false}},                             // timestamp
+				NullDate{Null: sql.Null[time.Time]{V: time.Date(2022, time.October, 10, 0, 0, 0, 0, time.UTC), Valid: true}}}, // date
 			`{
-				"mode": "offsync",
+				"mode": "online",
 				"sql": "SELECT c1, c2 FROM demo WHERE c1 = ? AND c2 = ?;",
 				"input": {
-					"schema": ["int32", "string"],
-					"data": [1, "bb"]
+					"schema": ["int16", "int32", "string", "string", "timestamp", "timestamp", "timestamp", "date"],
+					"data": [2, 1, "bb", "foo", 8000, 4000, null, "2022-10-10"]
 				}
 			}`,
 		},
@@ -60,6 +70,24 @@ func TestParseRespFromJson(t *testing.T) {
 				Data: nil,
 			},
 		},
+		{
+			`{
+				"code": 0,
+				"msg": "ok",
+				"data": {
+					"schema": ["date", "string"],
+					"data": []
+				}
+			}`,
+			queryResp{
+				Code: 0,
+				Msg:  "ok",
+				Data: &respData{
+					Schema: []string{"date", "string"},
+					Data:   [][]driver.Value{},
+				},
+			},
+		},
 		{
 			`{
 				"code": 0,
@@ -74,7 +102,7 @@ func TestParseRespFromJson(t *testing.T) {
 				Msg:  "ok",
 				Data: &respData{
 					Schema: []string{"Int32", "String"},
-					Data: [][]interfaces.Value{
+					Data: [][]driver.Value{
 						{int32(1), "bb"},
 						{int32(2), "bb"},
 					},
@@ -95,7 +123,7 @@ func TestParseRespFromJson(t *testing.T) {
 				Msg:  "ok",
 				Data: &respData{
 					Schema: []string{"Bool", "Int16", "Int32", "Int64", "Float", "Double", "String"},
-					Data: [][]interfaces.Value{
+					Data: [][]driver.Value{
 						{true, int16(1), int32(1), int64(1), float32(1), float64(1), "bb"},
 					},
 				},
diff --git a/encode.go b/encode.go
new file mode 100644
index 0000000..c1d7547
--- /dev/null
+++ b/encode.go
@@ -0,0 +1,15 @@
+package openmldb
+
+import (
+	"time"
+)
+
+func parseDateStr(src string) (time.Time, error) {
+	// api server returns date type as string formatted 'yyyy-mm-dd'
+	dval, err := time.Parse(time.DateOnly, src)
+	if err != nil {
+		return time.Time{}, err
+	}
+
+	return dval, nil
+}
diff --git a/go.mod b/go.mod
index e57fa16..bffc635 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
 module github.com/4paradigm/openmldb-go-sdk
 
-go 1.18
+go 1.22
 
 require github.com/stretchr/testify v1.9.0
 
diff --git a/go_sdk_test.go b/go_sdk_test.go
index 4dff6fa..0fa8401 100644
--- a/go_sdk_test.go
+++ b/go_sdk_test.go
@@ -18,95 +18,196 @@ import (
 
 var apiServer string
 
-// 1. NullTime + NullDate
-// 2. Time + Time
+var db *sql.DB
+var ctx context.Context
 
-func Test_driver(t *testing.T) {
-	db, err := sql.Open("openmldb", fmt.Sprintf("openmldb://%s/test_db", apiServer))
-	if err != nil {
-		t.Errorf("fail to open connect: %s", err)
-	}
+// user may use sql.NullXXX types to represent SQL values that may be null
 
-	defer func() {
-		if err := db.Close(); err != nil {
-			t.Errorf("fail to close connection: %s", err)
-		}
-	}()
+type demoStruct1 struct {
+	c1 int32
+	c2 string
+	ts time.Time
+	dt time.Time
+}
+type demoStruct2 struct {
+	c1 sql.NullInt32
+	c2 sql.NullString
+	ts sql.NullTime
+	dt sql.NullTime
+}
+type demoStruct3 struct {
+	c1 openmldb.Null[int32]
+	c2 openmldb.Null[string]
+	ts openmldb.Null[time.Time]
+	dt openmldb.NullDate
+}
 
-	ctx := context.Background()
+func TestPingCtx(t *testing.T) {
 	assert.NoError(t, db.PingContext(ctx), "fail to ping connect")
+}
+
+func TestQuery1(t *testing.T) {
+	// use time.Time to represent both timestamp and date
+	queryStmt := `SELECT * FROM demo`
+	rows, err := db.QueryContext(ctx, queryStmt)
+	assert.NoError(t, err, "fail to query %s", queryStmt)
+
+	var demo demoStruct1
+	{
+		assert.True(t, rows.Next())
+		assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt))
+		assert.Equal(t, demoStruct1{1, "bb", time.UnixMilli(3000), time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC)}, demo)
+	}
+	// {
+	// 	assert.True(t, rows.Next())
+	// 	assert.NoError(t, rows.Scan(&demo.c1, &demo.c2))
+	// 	assert.Equal(t, struct {
+	// 		c1 int32
+	// 		c2 string
+	// 	}{2, "bb"}, demo)
+	// }
+}
+
+func TestQuery2(t *testing.T) {
+	// use sql.NullTime to represent both timestamp and date
+	queryStmt := `SELECT * FROM demo`
+	rows, err := db.QueryContext(ctx, queryStmt)
+	assert.NoError(t, err, "fail to query %s", queryStmt)
+
+	var demo demoStruct2
+	assert.True(t, rows.Next())
+	assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt))
+	assert.Equal(t, sql.NullInt32{Int32: 1, Valid: true}, demo.c1)
+	assert.Equal(t, sql.NullString{String: "bb", Valid: true}, demo.c2)
+	assert.Equal(t, sql.NullTime{Time: time.UnixMilli(3000), Valid: true}, demo.ts)
+	assert.Equal(t, sql.NullTime{Time: time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC), Valid: true}, demo.dt)
+}
+
+func TestQuery3(t *testing.T) {
+	// use openmldb.Null[T] and openmldb.NullDate to represent timestamp and date
+	queryStmt := `SELECT * FROM demo`
+	rows, err := db.QueryContext(ctx, queryStmt)
+	assert.NoError(t, err, "fail to query %s", queryStmt)
+
+	var demo demoStruct3
+	assert.True(t, rows.Next())
+	assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt))
+	assert.Equal(t, openmldb.Null[int32]{Null: sql.Null[int32]{V: 1, Valid: true}}, demo.c1)
+	assert.Equal(t, openmldb.Null[string]{Null: sql.Null[string]{V: "bb", Valid: true}}, demo.c2)
+	assert.Equal(t, openmldb.Null[time.Time]{Null: sql.Null[time.Time]{V: time.UnixMilli(3000), Valid: true}}, demo.ts)
+	assert.Equal(t, openmldb.NullDate{Null: sql.Null[time.Time]{V: time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, demo.dt)
+}
+
+func TestQueryWithParams(t *testing.T) {
+	parameterQueryStmt := `SELECT * FROM demo WHERE c2 = ? AND c1 = ? AND ts = ?;`
+	rows, err := db.QueryContext(ctx, parameterQueryStmt, "bb", 1, time.UnixMilli(3000))
+	assert.NoError(t, err, "fail to query %s", parameterQueryStmt)
+
+	var demo demoStruct1
+	{
+		assert.True(t, rows.Next())
+		assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt))
+		assert.Equal(t, demoStruct1{1, "bb", time.UnixMilli(3000), time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC)}, demo)
+	}
+}
+
+func TestQueryWithParamsExpectsNull(t *testing.T) {
+	_, err := db.ExecContext(ctx, "create table test2 (id int16, val int64, dt date)")
+	assert.NoError(t, err)
+	t.Cleanup(func() {
+		_, err := db.ExecContext(ctx, "drop table test2")
+		assert.NoError(t, err)
+	})
+
+	{
+		_, err := db.ExecContext(ctx, "insert into test2 values (1, NULL, NULL)")
+		assert.NoError(t, err)
+	}
+
+	rows, err := db.QueryContext(ctx, "select * from test2 where id = ?", 1)
+	assert.NoError(t, err)
+	var demo struct {
+		id  sql.NullInt16
+		val sql.NullInt64
+		dt  sql.NullTime
+	}
+	{
+		assert.True(t, rows.Next())
+		assert.NoError(t, rows.Scan(&demo.id, &demo.val, &demo.dt))
+		assert.Equal(t, sql.NullInt16{Int16: 1, Valid: true}, demo.id)
+		assert.Equal(t, sql.NullInt64{Int64: 0, Valid: false}, demo.val)
+		assert.Equal(t, sql.NullTime{Time: time.Time{}, Valid: false}, demo.dt)
+	}
+}
+
+func TestQueryWithParamsResultsEmpty(t *testing.T) {
+	_, err := db.ExecContext(ctx, "create table test3 (id int16, val int64, dt date)")
+	assert.NoError(t, err)
+	t.Cleanup(func() {
+		_, err := db.ExecContext(ctx, "drop table test3")
+		assert.NoError(t, err)
+	})
+
+	{
+		_, err := db.ExecContext(ctx, "insert into test3 values (1, NULL, NULL)")
+		assert.NoError(t, err)
+	}
+
+	{
+		rows, err := db.QueryContext(ctx, "select * from test3 where id = ?", int16(10))
+		assert.NoError(t, err)
+		assert.False(t, rows.Next())
+	}
+
+	{
+		// disabled since https://github.com/4paradigm/OpenMLDB/issues/3902
+		// _, err := db.QueryContext(ctx, "select * from test3 where id = ?",
+		// 	openmldb.Null[int16]{Null: sql.Null[int16]{V: 0, Valid: false}})
+		// assert.NoError(t, err)
+		// assert.False(t, rows.Next())
+	}
+}
+
+func PrepareAndRun(m *testing.M) int {
+	var err error
+	db, err = sql.Open("openmldb", fmt.Sprintf("openmldb://%s/test_db", apiServer))
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "fail to open connect: %s", err)
+		os.Exit(1)
+	}
+
+	ctx = context.Background()
 
 	{
 		createTableStmt := "CREATE TABLE demo(c1 int, c2 string, ts timestamp, dt date);"
 		_, err := db.ExecContext(ctx, createTableStmt)
-		assert.NoError(t, err, "fail to exec %s", createTableStmt)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "fail to exec %s", createTableStmt)
+			os.Exit(1)
+		}
 	}
 
 	defer func() {
 		dropTableStmt := "DROP TABLE demo;"
 		_, err := db.ExecContext(ctx, dropTableStmt)
 		if err != nil {
-			t.Errorf("fail to drop table: %s", err)
+			fmt.Fprintf(os.Stderr, "fail to drop table: %s", err)
+			os.Exit(1)
 		}
 	}()
-
 	{
 		// FIXME: ordering issue
 		insertValueStmt := `INSERT INTO demo VALUES (1, "bb", 3000, "2022-12-12");`
 		// insertValueStmt := `INSERT INTO demo VALUES (1, "bb"), (2, "bb");`
 		_, err := db.ExecContext(ctx, insertValueStmt)
-		assert.NoError(t, err, "fail to exec %s", insertValueStmt)
-	}
-
-	t.Run("query", func(t *testing.T) {
-		queryStmt := `SELECT * FROM demo`
-		rows, err := db.QueryContext(ctx, queryStmt)
-		assert.NoError(t, err, "fail to query %s", queryStmt)
-
-		var demo struct {
-			c1 int32
-			c2 string
-			ts time.Time
-			dt openmldb.NullDate
-		}
-		{
-			assert.True(t, rows.Next())
-			assert.NoError(t, rows.Scan(&demo.c1, &demo.c2, &demo.ts, &demo.dt))
-			assert.Equal(t, struct {
-				c1 int32
-				c2 string
-				ts time.Time
-				dt openmldb.NullDate
-			}{1, "bb", time.UnixMilli(3000), openmldb.NullDate{Time: time.Date(2022, time.December, 12, 0, 0, 0, 0, time.UTC), Valid: true}}, demo)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "fail to exec: %s", insertValueStmt)
+			os.Exit(1)
 		}
-		// {
-		// 	assert.True(t, rows.Next())
-		// 	assert.NoError(t, rows.Scan(&demo.c1, &demo.c2))
-		// 	assert.Equal(t, struct {
-		// 		c1 int32
-		// 		c2 string
-		// 	}{2, "bb"}, demo)
-		// }
-	})
+	}
 
-	t.Run("query with parameter", func(t *testing.T) {
-		parameterQueryStmt := `SELECT c1, c2 FROM demo WHERE c2 = ? AND c1 = ?;`
-		rows, err := db.QueryContext(ctx, parameterQueryStmt, "bb", 1)
-		assert.NoError(t, err, "fail to query %s", parameterQueryStmt)
+	return m.Run()
 
-		var demo struct {
-			c1 int32
-			c2 string
-		}
-		{
-			assert.True(t, rows.Next())
-			assert.NoError(t, rows.Scan(&demo.c1, &demo.c2))
-			assert.Equal(t, struct {
-				c1 int32
-				c2 string
-			}{1, "bb"}, demo)
-		}
-	})
 }
 
 func TestMain(m *testing.M) {
@@ -117,5 +218,5 @@ func TestMain(m *testing.M) {
 		log.Fatalf("non-empty api server address required")
 	}
 
-	os.Exit(m.Run())
+	os.Exit(PrepareAndRun(m))
 }
diff --git a/types.go b/types.go
index bb009cf..f4ad1d1 100644
--- a/types.go
+++ b/types.go
@@ -1,49 +1,57 @@
 package openmldb
 
+// TODO(someone): support go < 1.22
+
 import (
 	"database/sql"
 	"database/sql/driver"
-	"errors"
+	"encoding/json"
 	"time"
 )
 
 var (
-	_ sql.Scanner = (*NullDate)(nil)
+	_ sql.Scanner   = (*NullDate)(nil)
 	_ driver.Valuer = NullDate{}
 )
 
+// Null represents a value that may be null.
+//
+// declare type embedded sql.Null so we still able to
+// utilize sql.Scanner and driver.Valuer in go standard,
+// and customize marshal logic for api requests
+type Null[T any] struct {
+	sql.Null[T]
+}
+
+// NullDate represents nullable SQL date in go
+//
+// embedded sql.Null[time.Time] so it by default
+// implements sql.Scanner and driver.Valuer, but
+// distinct timestamp representation in sdk.
 type NullDate struct {
-	Time  time.Time
-	Valid bool // Valid is true if Time is not NULL
+	sql.Null[time.Time]
 }
 
-// Scan implements sql.Scanner for NullDate
-func (dt *NullDate) Scan(src any) error {
-	switch val := src.(type) {
-	case string:
-		dval, err := time.Parse(time.DateOnly, val)
-		if err != nil {
-			dt.Valid = false
-			return err
-		} else {
-			dt.Time = dval
-			dt.Valid = true
-			return nil
-		}
-	case NullDate:
-		*dt = val
-		return nil
-	default:
-		return errors.New("scan NullDate from unsupported type")
+// MarshalJSON implements json.Marshaler
+func (src NullDate) MarshalJSON() ([]byte, error) {
+	if !src.Valid {
+		return json.Marshal(nil)
 	}
-
+	return json.Marshal(src.V.Format(time.DateOnly))
 }
 
-// Value implements driver.Valuer for NullDate
-func (dt NullDate) Value() (driver.Value, error) {
-	if !dt.Valid {
-		return nil, nil
+// MarshalJSON implements json.Marshaler for Null[T]
+func (src Null[T]) MarshalJSON() ([]byte, error) {
+	if !src.Valid {
+		return json.Marshal(nil)
 	}
-	return dt.Time, nil
 
+	var v any = src.V
+	switch val := v.(type) {
+	case time.Time:
+		// timestamp, marshal to int64 unix epoch time in millisecond
+		return json.Marshal(val.UnixMilli())
+	default:
+		return json.Marshal(src.V)
+	}
 }