From 280c0b7626871bbdcb4d6f4d25aaa63c8c88e468 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Fri, 2 Jun 2023 17:52:53 +1000 Subject: [PATCH 01/13] update tests --- conn.go | 31 ++- driver.go | 4 +- stmt.go | 31 +-- stmt_collection_parsing_test.go | 266 +++++++++++++++++++++ stmt_database_parsing_test.go | 167 ++++++++++++++ stmt_document_parsing_test.go | 1 + stmt_test.go | 393 -------------------------------- 7 files changed, 475 insertions(+), 418 deletions(-) create mode 100644 stmt_collection_parsing_test.go create mode 100644 stmt_database_parsing_test.go create mode 100644 stmt_document_parsing_test.go diff --git a/conn.go b/conn.go index 4068941..7bd0e8a 100644 --- a/conn.go +++ b/conn.go @@ -1,6 +1,7 @@ package gocosmos import ( + "context" "database/sql/driver" "errors" "time" @@ -10,29 +11,43 @@ var ( locGmt, _ = time.LoadLocation("GMT") ) -// Conn is Azure CosmosDB connection handle. +// Conn is Azure Cosmos DB implementation of driver.Conn. type Conn struct { - restClient *RestClient // Azure CosmosDB REST API client. + restClient *RestClient // Azure Cosmos DB REST API client. defaultDb string // default database used in Cosmos DB operations. } -// Prepare implements driver.Conn.Prepare. +// Prepare implements driver.Conn/Prepare. func (c *Conn) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +// PrepareContext implements driver.ConnPrepareContext/PrepareContext. +// +// @Available since v0.3.0 +func (c *Conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) { return parseQueryWithDefaultDb(c, c.defaultDb, query) } -// Close implements driver.Conn.Close. +// Close implements driver.Conn/Close. func (c *Conn) Close() error { return nil } -// Begin implements driver.Conn.Begin. +// Begin implements driver.Conn/Begin. func (c *Conn) Begin() (driver.Tx, error) { + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +// BeginTx implements driver.ConnBeginTx/BeginTx. +// +// @Available since v0.3.0 +func (c *Conn) BeginTx(_ context.Context, _ driver.TxOptions) (driver.Tx, error) { return nil, errors.New("transaction is not supported") } -// CheckNamedValue implements driver.NamedValueChecker.CheckNamedValue. -func (c *Conn) CheckNamedValue(value *driver.NamedValue) error { - // since CosmosDB is document db, it accepts any value types +// CheckNamedValue implements driver.NamedValueChecker/CheckNamedValue. +func (c *Conn) CheckNamedValue(_ *driver.NamedValue) error { + // since Cosmos DB is document db, it accepts any value types return nil } diff --git a/driver.go b/driver.go index c45e544..a38b06e 100644 --- a/driver.go +++ b/driver.go @@ -67,11 +67,11 @@ var ( ErrConflict = errors.New("StatusCode=409 Conflict") ) -// Driver is Azure CosmosDB driver for database/sql. +// Driver is Azure Cosmos DB implementation of driver.Driver. type Driver struct { } -// Open implements driver.Driver.Open. +// Open implements driver.Driver/Open. // // connStr is expected in the following format: // diff --git a/stmt.go b/stmt.go index 33361b1..47b741d 100644 --- a/stmt.go +++ b/stmt.go @@ -11,24 +11,25 @@ const ( field = `([\w\-]+)` ifNotExists = `(\s+IF\s+NOT\s+EXISTS)?` ifExists = `(\s+IF\s+EXISTS)?` - with = `((\s+WITH\s+([\w-]+)\s*=\s*([\w/\.,;:'"-]+))*)` + with = `(\s+WITH\s+` + field + `\s*=\s*([\w/\.\*,;:'"-]+)((\s+|\s*,\s+|\s+,\s*)WITH\s+` + field + `\s*=\s*([\w/\.\*,;:'"-]+))*)?` + // with = `((\s+WITH\s+([\w-]+)\s*=\s*([\w/\.,;:'"-]+))*)` ) var ( - reCreateDb = regexp.MustCompile(`(?is)^CREATE\s+DATABASE` + ifNotExists + `\s+` + field + with + `$`) - reAlterDb = regexp.MustCompile(`(?is)^ALTER\s+DATABASE` + `\s+` + field + with + `$`) - reDropDb = regexp.MustCompile(`(?is)^DROP\s+DATABASE` + ifExists + `\s+` + field + `$`) - reListDbs = regexp.MustCompile(`(?is)^LIST\s+DATABASES?$`) - - reCreateColl = regexp.MustCompile(`(?is)^CREATE\s+(COLLECTION|TABLE)` + ifNotExists + `\s+(` + field + `\.)?` + field + with + `$`) - reAlterColl = regexp.MustCompile(`(?is)^ALTER\s+(COLLECTION|TABLE)` + `\s+(` + field + `\.)?` + field + with + `$`) - reDropColl = regexp.MustCompile(`(?is)^DROP\s+(COLLECTION|TABLE)` + ifExists + `\s+(` + field + `\.)?` + field + `$`) - reListColls = regexp.MustCompile(`(?is)^LIST\s+(COLLECTIONS?|TABLES?)(\s+FROM\s+` + field + `)?$`) - - reInsert = regexp.MustCompile(`(?is)^(INSERT|UPSERT)\s+INTO\s+(` + field + `\.)?` + field + `\s*\(([^)]*?)\)\s*VALUES\s*\(([^)]*?)\)$`) - reSelect = regexp.MustCompile(`(?is)^SELECT\s+(CROSS\s+PARTITION\s+)?.*?\s+FROM\s+` + field + `.*?` + with + `$`) - reUpdate = regexp.MustCompile(`(?is)^UPDATE\s+(` + field + `\.)?` + field + `\s+SET\s+(.*)\s+WHERE\s+id\s*=\s*(.*)$`) - reDelete = regexp.MustCompile(`(?is)^DELETE\s+FROM\s+(` + field + `\.)?` + field + `\s+WHERE\s+id\s*=\s*(.*)$`) + reCreateDb = regexp.MustCompile(`(?im)^CREATE\s+DATABASE` + ifNotExists + `\s+` + field + with + `$`) + reAlterDb = regexp.MustCompile(`(?im)^ALTER\s+DATABASE` + `\s+` + field + with + `$`) + reDropDb = regexp.MustCompile(`(?im)^DROP\s+DATABASE` + ifExists + `\s+` + field + `$`) + reListDbs = regexp.MustCompile(`(?im)^LIST\s+DATABASES?$`) + + reCreateColl = regexp.MustCompile(`(?im)^CREATE\s+(COLLECTION|TABLE)` + ifNotExists + `\s+(` + field + `\.)?` + field + with + `$`) + reAlterColl = regexp.MustCompile(`(?im)^ALTER\s+(COLLECTION|TABLE)` + `\s+(` + field + `\.)?` + field + with + `$`) + reDropColl = regexp.MustCompile(`(?im)^DROP\s+(COLLECTION|TABLE)` + ifExists + `\s+(` + field + `\.)?` + field + `$`) + reListColls = regexp.MustCompile(`(?im)^LIST\s+(COLLECTIONS?|TABLES?)(\s+FROM\s+` + field + `)?$`) + + reInsert = regexp.MustCompile(`(?im)^(INSERT|UPSERT)\s+INTO\s+(` + field + `\.)?` + field + `\s*\(([^)]*?)\)\s*VALUES\s*\(([^)]*?)\)$`) + reSelect = regexp.MustCompile(`(?im)^SELECT\s+(CROSS\s+PARTITION\s+)?.*?\s+FROM\s+` + field + `.*?` + with + `$`) + reUpdate = regexp.MustCompile(`(?im)^UPDATE\s+(` + field + `\.)?` + field + `\s+SET\s+(.*)\s+WHERE\s+id\s*=\s*(.*)$`) + reDelete = regexp.MustCompile(`(?im)^DELETE\s+FROM\s+(` + field + `\.)?` + field + `\s+WHERE\s+id\s*=\s*(.*)$`) ) func parseQuery(c *Conn, query string) (driver.Stmt, error) { diff --git a/stmt_collection_parsing_test.go b/stmt_collection_parsing_test.go new file mode 100644 index 0000000..5d711fd --- /dev/null +++ b/stmt_collection_parsing_test.go @@ -0,0 +1,266 @@ +package gocosmos + +import ( + "reflect" + "testing" +) + +func TestStmtCreateCollection_parse(t *testing.T) { + testName := "TestStmtCreateCollection_parse" + testData := []struct { + name string + sql string + expected *StmtCreateCollection + mustError bool + }{ + {name: "error_no_pk", sql: "CREATE collection db.coll", mustError: true}, + {name: "error_pk_and_large_pk", sql: "CREATE collection db.coll WITH pk=/a WITH largepk=/b", mustError: true}, + {name: "error_invalid_pk", sql: "CREATE collection db.coll WITH pk=", mustError: true}, + {name: "error_invalid_large_pk", sql: "CREATE collection db.coll WITH largepk=", mustError: true}, + {name: "error_ru_and_maxru", sql: "CREATE collection db.coll WITH pk=/id WITH ru=400 WITH maxru=1000", mustError: true}, + {name: "error_invalid_ru", sql: "create TABLE db.coll WITH pk=/id WITH ru=-1 WITH maxru=1000", mustError: true}, + {name: "error_invalid_maxru", sql: "CREATE COLLECTION db.coll WITH pk=/id WITH ru=400 WITH maxru=-1", mustError: true}, + {name: "error_invalid_ru2", sql: "CREATE TABLE db.table WITH pk=/id WITH ru=-1", mustError: true}, + {name: "error_invalid_maxru2", sql: "CREATE COLLECTION db.table WITH pk=/id WITH maxru=-1", mustError: true}, + {name: "error_no_collection", sql: "CREATE TABLE db WITH pk=/id", mustError: true}, + {name: "error_if_not_exist", sql: "CREATE TABLE IF NOT EXIST db.table WITH pk=/id", mustError: true}, + + {name: "basic", sql: "CREATE COLLECTION db1.table1 WITH pk=/id", expected: &StmtCreateCollection{dbName: "db1", collName: "table1", pk: "/id"}}, + {name: "table_with_ru", sql: "create\ntable\rdb-2.table_2 WITH\tPK=/email WITH\r\nru=100", expected: &StmtCreateCollection{dbName: "db-2", collName: "table_2", pk: "/email", ru: 100}}, + {name: "if_not_exists_large_pk_with_maxru", sql: "CREATE collection\nIF\rNOT\t\nEXISTS\n\tdb_3.table-3 with largePK=/id WITH\t\rmaxru=100", expected: &StmtCreateCollection{dbName: "db_3", collName: "table-3", ifNotExists: true, isLargePk: true, pk: "/id", maxru: 100}}, + {name: "table_if_not_exists_large_pk_with_uk", sql: "create TABLE if not exists db-0_1.table_0-1 WITH LARGEpk=/a/b/c with uk=/a:/b,/c/d;/e/f/g", expected: &StmtCreateCollection{dbName: "db-0_1", collName: "table_0-1", ifNotExists: true, isLargePk: true, pk: "/a/b/c", uk: [][]string{{"/a"}, {"/b", "/c/d"}, {"/e/f/g"}}}}, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtCreateCollection) + if !ok { + t.Fatalf("%s failed: expected StmtCreateCollection but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + stmt.withOptsStr = "" + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtCreateCollection_parse_defaultDb(t *testing.T) { + testName := "TestStmtCreateCollection_parse_defaultDb" + testData := []struct { + name string + db string + sql string + expected *StmtCreateCollection + mustError bool + }{ + {name: "error_invalid_query", db: "mydb", sql: "CREATE TABLE .mytable WITH pk=/id", mustError: true}, + + {name: "basic", db: "mydb", sql: "CREATE COLLECTION table1 WITH pk=/id", expected: &StmtCreateCollection{dbName: "mydb", collName: "table1", pk: "/id"}}, + {name: "db_in_query", db: "mydb", sql: "create\ntable\r\ndb2.table_2 WITH\r\t\nPK=/email WITH\nru=100", expected: &StmtCreateCollection{dbName: "db2", collName: "table_2", pk: "/email", ru: 100}}, + {name: "if_not_exists", db: "mydb", sql: "CREATE collection\nIF\nNOT\t\nEXISTS\n\ttable-3 with largePK=/id WITH\tmaxru=100", expected: &StmtCreateCollection{dbName: "mydb", collName: "table-3", ifNotExists: true, isLargePk: true, pk: "/id", maxru: 100}}, + {name: "db_in_query_if_not_exists", db: "mydb", sql: "create TABLE if not exists db3.table_0-1 WITH LARGEpk=/a/b/c with uk=/a:/b,/c/d;/e/f/g", expected: &StmtCreateCollection{dbName: "db3", collName: "table_0-1", ifNotExists: true, isLargePk: true, pk: "/a/b/c", uk: [][]string{{"/a"}, {"/b", "/c/d"}, {"/e/f/g"}}}}, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtCreateCollection) + if !ok { + t.Fatalf("%s failed: expected StmtCreateCollection but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + stmt.withOptsStr = "" + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtAlterCollection_parse(t *testing.T) { + testName := "TestStmtAlterCollection_parse" + testData := []struct { + name string + sql string + expected *StmtAlterCollection + mustError bool + }{ + {name: "error_no_ru_maxru", sql: "ALTER collection db.coll", mustError: true}, + {name: "error_no_db", sql: "ALTER collection coll WITH ru=400", mustError: true}, + {name: "error_invalid_query", sql: "ALTER collection .coll WITH maxru=4000", mustError: true}, + {name: "error_ru_and_maxru", sql: "alter TABLE db.coll WITH ru=400 WITH maxru=4000", mustError: true}, + {name: "error_invalid_ru", sql: "alter TABLE db.coll WITH ru=-1", mustError: true}, + {name: "error_invalid_maxru", sql: "alter TABLE db.coll WITH maxru=-1", mustError: true}, + + {name: "basic", sql: "ALTER collection db1.table1 WITH ru=400", expected: &StmtAlterCollection{dbName: "db1", collName: "table1", ru: 400}}, + {name: "table", sql: "alter\nTABLE\rdb-2.table_2 WITH\tmaxru=40000", expected: &StmtAlterCollection{dbName: "db-2", collName: "table_2", maxru: 40000}}, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtAlterCollection) + if !ok { + t.Fatalf("%s failed: expected StmtAlterCollection but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + stmt.withOptsStr = "" + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtAlterCollection_parse_defaultDb(t *testing.T) { + testName := "TestStmtAlterCollection_parse_defaultDb" + testData := []struct { + name string + db string + sql string + expected *StmtAlterCollection + mustError bool + }{ + {name: "error_invalid_query", db: "mydb", sql: "ALTER COLLECTION .mytable WITH ru=400", mustError: true}, + {name: "error_notable", db: "mydb", sql: "ALTER COLLECTION mydb. WITH ru=400", mustError: true}, + {name: "error_no_db_table", db: "mydb", sql: "ALTER COLLECTION WITH ru=400", mustError: true}, + + {name: "basic", db: "mydb", sql: "ALTER collection table1 WITH ru=400", expected: &StmtAlterCollection{dbName: "mydb", collName: "table1", ru: 400}}, + {name: "db_in_query", db: "mydb", sql: "alter\nTABLE\rdb-2.table_2 WITH\tmaxru=40000", expected: &StmtAlterCollection{dbName: "db-2", collName: "table_2", maxru: 40000}}, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtAlterCollection) + if !ok { + t.Fatalf("%s failed: expected StmtAlterCollection but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + stmt.withOptsStr = "" + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtDropCollection_parse(t *testing.T) { + testName := "TestStmtDropCollection_parse" + testData := []struct { + name string + sql string + expected *StmtDropCollection + mustError bool + }{ + {name: "error_no_collection", sql: "DROP collection db", mustError: true}, + {name: "error_no_collection2", sql: "Drop Table db.", mustError: true}, + {name: "error_invalid_query", sql: "DROP COLLECTION .mytable", mustError: true}, + {name: "error_if_not_exists", sql: "DROP COLLECTION IF NOT EXISTS mydb.mytable", mustError: true}, + {name: "error_if_exist", sql: "DROP COLLECTION IF EXIST mydb.mytable", mustError: true}, + + {name: "basic", sql: "DROP \rCOLLECTION\n db1.table1", expected: &StmtDropCollection{dbName: "db1", collName: "table1"}}, + {name: "table", sql: "DROP\t\rtable\n\tdb-2.table_2", expected: &StmtDropCollection{dbName: "db-2", collName: "table_2"}}, + {name: "if_exists", sql: "drop \rcollection\n IF EXISTS \t db_3.table-3", expected: &StmtDropCollection{dbName: "db_3", ifExists: true, collName: "table-3"}}, + {name: "table_if_exists", sql: "Drop Table If Exists db-4_0.table_4-0", expected: &StmtDropCollection{dbName: "db-4_0", ifExists: true, collName: "table_4-0"}}, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtDropCollection) + if !ok { + t.Fatalf("%s failed: expected StmtDropCollection but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtDropCollection_parse_defaultDb(t *testing.T) { + testName := "TestStmtDropCollection_parse_defaultDb" + testData := []struct { + name string + db string + sql string + expected *StmtDropCollection + mustError bool + }{ + {name: "error_invalid_query", db: "mydb", sql: "DROP collection .mytable", mustError: true}, + {name: "error_if_not_exists", db: "mydb", sql: "DROP COLLECTION IF NOT EXISTS mydb.mytable", mustError: true}, + {name: "error_if_exists", db: "mydb", sql: "DROP COLLECTION IF EXIST mydb.mytable", mustError: true}, + + {name: "basic", db: "mydb", sql: "DROP COLLECTION table1", expected: &StmtDropCollection{dbName: "mydb", collName: "table1"}}, + {name: "db_in_query", db: "mydb", sql: "DROP\t\rtable\n\tdb-2.table_2", expected: &StmtDropCollection{dbName: "db-2", collName: "table_2"}}, + {name: "if_exists", db: "mydb", sql: "drop \tcollection\r IF EXISTS \n table-3", expected: &StmtDropCollection{dbName: "mydb", ifExists: true, collName: "table-3"}}, + {name: "table_if_exists", db: "mydb", sql: "Drop Table If Exists db-4_0.table_4-0", expected: &StmtDropCollection{dbName: "db-4_0", ifExists: true, collName: "table_4-0"}}, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtDropCollection) + if !ok { + t.Fatalf("%s failed: expected StmtDropCollection but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} diff --git a/stmt_database_parsing_test.go b/stmt_database_parsing_test.go new file mode 100644 index 0000000..c55221b --- /dev/null +++ b/stmt_database_parsing_test.go @@ -0,0 +1,167 @@ +package gocosmos + +import ( + "reflect" + "testing" +) + +func TestStmtCreateDatabase_parse(t *testing.T) { + testName := "TestStmtCreateDatabase_parse" + testData := []struct { + name string + sql string + expected *StmtCreateDatabase + mustError bool + }{ + {name: "error_no_table", sql: "CREATE DATABASE ", mustError: true}, + {name: "error_if_not_exists_no_table", sql: "CREATE DATABASE IF NOT EXISTS ", mustError: true}, + {name: "error_syntax", sql: "CREATE DATABASE db0 IF NOT EXISTS", mustError: true}, + {name: "error_if_exists", sql: "CREATE DATABASE if exists db0", mustError: true}, + {name: "error_if_not_exist", sql: "CREATE DATABASE IF NOT EXIST db0", mustError: true}, + {name: "error_ru_and_maxru", sql: "CREATE DATABASE db0 with RU=400, WITH MAXru=4000", mustError: true}, + + {name: "basic", sql: "CREATE DATABASE db1", expected: &StmtCreateDatabase{dbName: "db1"}}, + {name: "with_ru", sql: "create\ndatabase\n db-2 \nWITH \n ru=100", expected: &StmtCreateDatabase{dbName: "db-2", ru: 100}}, + {name: "with_max", sql: "CREATE\r\nDATABASE \n \r db_3 \r \n with\n\rmaxru=100", expected: &StmtCreateDatabase{dbName: "db_3", maxru: 100}}, + {name: "if_not_exists", sql: "CREATE DATABASE\tIF\rNOT\nEXISTS db-4-0", expected: &StmtCreateDatabase{dbName: "db-4-0", ifNotExists: true}}, + {name: "if_not_exists_with_ru", sql: "create\ndatabase IF NOT EXISTS db-5_0 with\nru=100", expected: &StmtCreateDatabase{dbName: "db-5_0", ifNotExists: true, ru: 100}}, + {name: "if_not_exists_with_maxru", sql: "CREATE DATABASE if not exists db_6-0 WITH maxru=100", expected: &StmtCreateDatabase{dbName: "db_6-0", ifNotExists: true, maxru: 100}}, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtCreateDatabase) + if !ok { + t.Fatalf("%s failed: expected StmtCreateDatabase but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + stmt.withOptsStr = "" + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtAlterDatabase_parse(t *testing.T) { + testName := "TestStmtAlterDatabase_parse" + testData := []struct { + name string + sql string + expected *StmtAlterDatabase + mustError bool + }{ + {name: "error_no_ru_maxru", sql: "ALTER database db0", mustError: true}, + {name: "error_ru_and_maxru", sql: "ALTER database db0 WITH RU=400, WITH maxRU=4000", mustError: true}, + + {name: "with_ru", sql: "ALTER\rdatabase\ndb1\tWITH ru=400", expected: &StmtAlterDatabase{dbName: "db1", ru: 400}}, + {name: "with_maxru", sql: "alter DATABASE db-1 with maxru=4000", expected: &StmtAlterDatabase{dbName: "db-1", maxru: 4000}}, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtAlterDatabase) + if !ok { + t.Fatalf("%s failed: expected StmtAlterDatabase but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + stmt.withOptsStr = "" + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtDropDatabase_parse(t *testing.T) { + testName := "TestStmtDropDatabase_parse" + testData := []struct { + name string + sql string + expected *StmtDropDatabase + mustError bool + }{ + {name: "error_if_exist", sql: "DROP DATABASE IF EXIST db1", mustError: true}, + {name: "error_no_db", sql: "DROP DATABASE ", mustError: true}, + + {name: "basic", sql: "DROP DATABASE db1", expected: &StmtDropDatabase{dbName: "db1"}}, + {name: "lfcr", sql: "DROP\ndatabase\rdb-2", expected: &StmtDropDatabase{dbName: "db-2"}}, + {name: "if_exists", sql: "drop\rdatabase\nIF\nEXISTS db_3", expected: &StmtDropDatabase{dbName: "db_3", ifExists: true}}, + {name: "if_exists_2", sql: "Drop Database \tIf\t Exists \t db-4_0", expected: &StmtDropDatabase{dbName: "db-4_0", ifExists: true}}, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtDropDatabase) + if !ok { + t.Fatalf("%s failed: expected StmtDropDatabase but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtListDatabases_parse(t *testing.T) { + testName := "TestStmtListDatabases_parse" + testData := []struct { + name string + sql string + expected *StmtListDatabases + mustError bool + }{ + {name: "basic", sql: "LIST DATABASES", expected: &StmtListDatabases{}}, + {name: "database", sql: " lisT \r\t\n Database ", expected: &StmtListDatabases{}}, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtListDatabases) + if !ok { + t.Fatalf("%s failed: expected StmtListDatabases but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} diff --git a/stmt_document_parsing_test.go b/stmt_document_parsing_test.go new file mode 100644 index 0000000..c7270a2 --- /dev/null +++ b/stmt_document_parsing_test.go @@ -0,0 +1 @@ +package gocosmos diff --git a/stmt_test.go b/stmt_test.go index b06cfc5..274c87c 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -1,7 +1,6 @@ package gocosmos import ( - "fmt" "reflect" "testing" ) @@ -37,400 +36,8 @@ func TestStmt_NumInput(t *testing.T) { } } -func Test_parseQuery_CreateDatabase(t *testing.T) { - name := "Test_parseQuery_CreateDatabase" - type testStruct struct { - dbName string - ifNotExists bool - ru, maxru int - } - testData := map[string]testStruct{ - "CREATE DATABASE db1": {dbName: "db1", ifNotExists: false, ru: 0, maxru: 0}, - "create database\ndb-2\r\nWITH ru=100": {dbName: "db-2", ifNotExists: false, ru: 100, maxru: 0}, - "CREATE\nDATABASE\r\ndb_3\nwith\r\nmaxru=100": {dbName: "db_3", ifNotExists: false, ru: 0, maxru: 100}, - "CREATE DATABASE\r\nIF NOT EXISTS\ndb-4-0": {dbName: "db-4-0", ifNotExists: true, ru: 0, maxru: 0}, - "create\ndatabase IF NOT EXISTS db-5_0 with\r\nru=100": {dbName: "db-5_0", ifNotExists: true, ru: 100, maxru: 0}, - "CREATE DATABASE if not exists db_6-0 WITH maxru=100": {dbName: "db_6-0", ifNotExists: true, ru: 0, maxru: 100}, - } - for query, data := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtCreateDatabase); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtCreateDatabase", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.ifNotExists != data.ifNotExists { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.ifNotExists, dbstmt.ifNotExists) - } else if dbstmt.ru != data.ru { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.ru, dbstmt.ru) - } else if dbstmt.maxru != data.maxru { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.maxru, dbstmt.maxru) - } - } - - invalidQueries := []string{ - "CREATE DATABASE dbtemp WITH ru=400 WITH maxru=1000", - "CREATE DATABASE dbtemp WITH ru=-1 WITH maxru=1000", - "CREATE DATABASE dbtemp WITH ru=400 WITH maxru=-1", - "CREATE DATABASE dbtemp WITH ru=-1", - "CREATE DATABASE dbtemp WITH maxru=-1", - } - for _, query := range invalidQueries { - if _, err := parseQuery(nil, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_AlterDatabase(t *testing.T) { - name := "Test_parseQuery_AlterDatabase" - type testStruct struct { - dbName string - ru, maxru int - } - testData := map[string]testStruct{ - "ALTER database db1 WITH ru=400": {dbName: "db1", ru: 400, maxru: 0}, - "alter DATABASE db-1 with maxru=4000": {dbName: "db-1", ru: 0, maxru: 4000}, - } - for query, data := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtAlterDatabase); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtAlterDatabase", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.ru != data.ru { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.ru, dbstmt.ru) - } else if dbstmt.maxru != data.maxru { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.maxru, dbstmt.maxru) - } - } - - invalidQueries := []string{ - "ALTER DATABASE dbtemp", - "ALTER DATABASE dbtemp WITH ru=400 WITH maxru=4000", - "ALTER DATABASE dbtemp WITH ru=-1", - "ALTER DATABASE dbtemp WITH maxru=-1", - } - for _, query := range invalidQueries { - if _, err := parseQuery(nil, query); err == nil { - temp, _ := parseQuery(nil, query) - fmt.Printf("%#v\n", temp) - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_DropDatabase(t *testing.T) { - name := "Test_parseQuery_DropDatabase" - type testStruct struct { - dbName string - ifExists bool - } - testData := map[string]testStruct{ - "DROP DATABASE db1": {dbName: "db1", ifExists: false}, - "DROP\ndatabase\r\ndb-2": {dbName: "db-2", ifExists: false}, - "drop database\r\nIF\nEXISTS db_3": {dbName: "db_3", ifExists: true}, - "Drop Database If Exists db-4_0": {dbName: "db-4_0", ifExists: true}, - } - - for query, data := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtDropDatabase); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtDropDatabase", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.ifExists != data.ifExists { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.ifExists, dbstmt.ifExists) - } - } -} - -func Test_parseQuery_ListDatabases(t *testing.T) { - name := "Test_parseQuery_ListDatabases" - testData := []string{"LIST\nDATABASES", "list\r\n database"} - - for _, query := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if _, ok := stmt.(*StmtListDatabases); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtListDatabases", name+"/"+query) - } - } -} - /*----------------------------------------------------------------------*/ -func Test_parseQuery_CreateCollection(t *testing.T) { - name := "Test_parseQuery_CreateCollection" - type testStruct struct { - dbName string - collName string - ifNotExists bool - ru, maxru int - pk string - isLargePk bool - uk [][]string - } - testData := map[string]testStruct{ - "CREATE COLLECTION db1.table1 WITH pk=/id": {dbName: "db1", collName: "table1", ifNotExists: false, ru: 0, maxru: 0, pk: "/id", isLargePk: false, uk: nil}, - "create\ntable\r\ndb-2.table_2 WITH\r\nPK=/email WITH\nru=100": {dbName: "db-2", collName: "table_2", ifNotExists: false, ru: 100, maxru: 0, pk: "/email", isLargePk: false, uk: nil}, - "CREATE collection\nIF\nNOT\t\nEXISTS\r\n\tdb_3.table-3 with largePK=/id WITH\tmaxru=100": {dbName: "db_3", collName: "table-3", ifNotExists: true, ru: 0, maxru: 100, pk: "/id", isLargePk: true, uk: nil}, - "create TABLE if not exists db-0_1.table_0-1 WITH LARGEpk=/a/b/c with uk=/a:/b,/c/d;/e/f/g": {dbName: "db-0_1", collName: "table_0-1", ifNotExists: true, ru: 0, maxru: 0, pk: "/a/b/c", isLargePk: false, uk: [][]string{{"/a"}, {"/b", "/c/d"}, {"/e/f/g"}}}, - } - for query, data := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtCreateCollection); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtCreateCollection", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if dbstmt.ifNotExists != data.ifNotExists { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.ifNotExists, dbstmt.ifNotExists) - } else if dbstmt.ru != data.ru { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.ru, dbstmt.ru) - } else if dbstmt.maxru != data.maxru { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.maxru, dbstmt.maxru) - } else if dbstmt.pk != data.pk { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.pk, dbstmt.pk) - } else if !reflect.DeepEqual(dbstmt.uk, data.uk) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.uk, dbstmt.uk) - } - } - - invalidQueries := []string{ - "CREATE collection db.coll", - "CREATE collection db.coll WITH pk=/a WITH largepk=/b", - "CREATE collection db.coll WITH pk=", - "CREATE collection db.coll WITH largepk=", - "CREATE collection db.coll WITH pk=/id WITH ru=400 WITH maxru=1000", - "create TABLE db.coll WITH pk=/id WITH ru=-1 WITH maxru=1000", - "CREATE COLLECTION db.coll WITH pk=/id WITH ru=400 WITH maxru=-1", - "CREATE TABLE db.table WITH pk=/id WITH ru=-1", - "CREATE COLLECTION db.table WITH pk=/id WITH ru=-1", - "CREATE TABLE db WITH pk=/id", // no collection name - } - for _, query := range invalidQueries { - if _, err := parseQuery(nil, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_CreateCollectionDefaultDb(t *testing.T) { - name := "Test_parseQuery_CreateCollectionDefaultDb" - dbName := "mydb" - type testStruct struct { - dbName string - collName string - ifNotExists bool - ru, maxru int - pk string - isLargePk bool - uk [][]string - } - testData := map[string]testStruct{ - "CREATE COLLECTION table1 WITH pk=/id": {dbName: dbName, collName: "table1", ifNotExists: false, ru: 0, maxru: 0, pk: "/id", isLargePk: false, uk: nil}, - "create\ntable\r\ndb2.table_2 WITH\r\nPK=/email WITH\nru=100": {dbName: "db2", collName: "table_2", ifNotExists: false, ru: 100, maxru: 0, pk: "/email", isLargePk: false, uk: nil}, - "CREATE collection\nIF\nNOT\t\nEXISTS\r\n\ttable-3 with largePK=/id WITH\tmaxru=100": {dbName: dbName, collName: "table-3", ifNotExists: true, ru: 0, maxru: 100, pk: "/id", isLargePk: true, uk: nil}, - "create TABLE if not exists db3.table_0-1 WITH LARGEpk=/a/b/c with uk=/a:/b,/c/d;/e/f/g": {dbName: "db3", collName: "table_0-1", ifNotExists: true, ru: 0, maxru: 0, pk: "/a/b/c", isLargePk: false, uk: [][]string{{"/a"}, {"/b", "/c/d"}, {"/e/f/g"}}}, - } - for query, data := range testData { - if stmt, err := parseQueryWithDefaultDb(nil, dbName, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtCreateCollection); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtCreateCollection", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if dbstmt.ifNotExists != data.ifNotExists { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.ifNotExists, dbstmt.ifNotExists) - } else if dbstmt.ru != data.ru { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.ru, dbstmt.ru) - } else if dbstmt.maxru != data.maxru { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.maxru, dbstmt.maxru) - } else if dbstmt.pk != data.pk { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.pk, dbstmt.pk) - } else if !reflect.DeepEqual(dbstmt.uk, data.uk) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.uk, dbstmt.uk) - } - } - - invalidQueries := []string{ - "CREATE TABLE .mytable WITH pk=/id", - } - for _, query := range invalidQueries { - if _, err := parseQueryWithDefaultDb(nil, dbName, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_AlterCollection(t *testing.T) { - name := "Test_parseQuery_AlterCollection" - type testStruct struct { - dbName string - collName string - ru, maxru int - } - testData := map[string]testStruct{ - "ALTER collection db1.table1 WITH ru=400": {dbName: "db1", collName: "table1", ru: 400, maxru: 0}, - "alter\nTABLE\r\ndb-2.table_2 WITH\r\nmaxru=40000": {dbName: "db-2", collName: "table_2", ru: 0, maxru: 40000}, - } - for query, data := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtAlterCollection); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtCreateCollection", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if dbstmt.ru != data.ru { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.ru, dbstmt.ru) - } else if dbstmt.maxru != data.maxru { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.maxru, dbstmt.maxru) - } - } - - invalidQueries := []string{ - "ALTER collection db.coll", - "ALTER collection coll WITH ru=400", - "ALTER collection .coll WITH maxru=4000", - "alter TABLE db.coll WITH ru=400 WITH maxru=4000", - "alter TABLE db.coll WITH ru=-1", - "alter TABLE db.coll WITH maxru=-1", - } - for _, query := range invalidQueries { - if _, err := parseQuery(nil, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_AlterCollectionDefaultDb(t *testing.T) { - name := "Test_parseQuery_AlterCollectionDefaultDb" - dbName := "mydb" - type testStruct struct { - dbName string - collName string - ru, maxru int - } - testData := map[string]testStruct{ - "ALTER collection db1.table1 WITH ru=400": {dbName: "db1", collName: "table1", ru: 400, maxru: 0}, - "alter\nTABLE\r\ndb-2.table_2 WITH\r\nmaxru=40000": {dbName: "db-2", collName: "table_2", ru: 0, maxru: 40000}, - "ALTER collection table1 WITH ru=400": {dbName: dbName, collName: "table1", ru: 400, maxru: 0}, - "alter\nTABLE\r\ntable_2 WITH\r\nmaxru=40000": {dbName: dbName, collName: "table_2", ru: 0, maxru: 40000}, - } - for query, data := range testData { - if stmt, err := parseQueryWithDefaultDb(nil, dbName, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtAlterCollection); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtCreateCollection", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if dbstmt.ru != data.ru { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.ru, dbstmt.ru) - } else if dbstmt.maxru != data.maxru { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.maxru, dbstmt.maxru) - } - } - - invalidQueries := []string{ - "ALTER COLLECTION .mytable WITH ru=400", - "ALTER COLLECTION WITH ru=400", - } - for _, query := range invalidQueries { - if _, err := parseQueryWithDefaultDb(nil, dbName, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_DropCollection(t *testing.T) { - name := "Test_parseQuery_DropCollection" - type testStruct struct { - dbName string - collName string - ifExists bool - } - testData := map[string]testStruct{ - "DROP COLLECTION db1.table1": {dbName: "db1", collName: "table1", ifExists: false}, - "DROP\t\ntable\r\n\tdb-2.table_2": {dbName: "db-2", collName: "table_2", ifExists: false}, - "drop collection\nIF EXISTS\tdb_3.table-3": {dbName: "db_3", collName: "table-3", ifExists: true}, - "Drop Table If Exists db-4_0.table_4-0": {dbName: "db-4_0", collName: "table_4-0", ifExists: true}, - } - - for query, data := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtDropCollection); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtDropDatabase", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.ifExists != data.ifExists { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.ifExists, dbstmt.ifExists) - } - } - - invalidQueries := []string{ - "DROP collection db", // no collection name - "drop TABLE db", // no collection name - } - for _, query := range invalidQueries { - if _, err := parseQuery(nil, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_DropCollectionDefaultDb(t *testing.T) { - name := "Test_parseQuery_DropCollectionDefaultDb" - dbName := "mydb" - type testStruct struct { - dbName string - collName string - ifExists bool - } - testData := map[string]testStruct{ - "DROP COLLECTION table1": {dbName: dbName, collName: "table1", ifExists: false}, - "DROP\t\ntable\r\n\tdb-2.table_2": {dbName: "db-2", collName: "table_2", ifExists: false}, - "drop collection\nIF EXISTS\ttable-3": {dbName: dbName, collName: "table-3", ifExists: true}, - "Drop Table If Exists db-4_0.table_4-0": {dbName: "db-4_0", collName: "table_4-0", ifExists: true}, - } - - for query, data := range testData { - if stmt, err := parseQueryWithDefaultDb(nil, dbName, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtDropCollection); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtDropDatabase", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.ifExists != data.ifExists { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.ifExists, dbstmt.ifExists) - } - } - - invalidQueries := []string{ - "DROP collection .mytable", - } - for _, query := range invalidQueries { - if _, err := parseQueryWithDefaultDb(nil, dbName, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - func Test_parseQuery_ListCollections(t *testing.T) { name := "Test_parseQuery_ListCollections" testData := map[string]string{ From aa0e7245b978492125ad3015cb4802c536facfa1 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Sun, 4 Jun 2023 22:16:13 +1000 Subject: [PATCH 02/13] update tests --- stmt.go | 28 +- stmt_collection_parsing_test.go | 84 +++++ stmt_document_parsing_test.go | 614 ++++++++++++++++++++++++++++++++ stmt_test.go | 604 ------------------------------- 4 files changed, 712 insertions(+), 618 deletions(-) diff --git a/stmt.go b/stmt.go index 47b741d..ffe3d79 100644 --- a/stmt.go +++ b/stmt.go @@ -16,20 +16,20 @@ const ( ) var ( - reCreateDb = regexp.MustCompile(`(?im)^CREATE\s+DATABASE` + ifNotExists + `\s+` + field + with + `$`) - reAlterDb = regexp.MustCompile(`(?im)^ALTER\s+DATABASE` + `\s+` + field + with + `$`) - reDropDb = regexp.MustCompile(`(?im)^DROP\s+DATABASE` + ifExists + `\s+` + field + `$`) - reListDbs = regexp.MustCompile(`(?im)^LIST\s+DATABASES?$`) - - reCreateColl = regexp.MustCompile(`(?im)^CREATE\s+(COLLECTION|TABLE)` + ifNotExists + `\s+(` + field + `\.)?` + field + with + `$`) - reAlterColl = regexp.MustCompile(`(?im)^ALTER\s+(COLLECTION|TABLE)` + `\s+(` + field + `\.)?` + field + with + `$`) - reDropColl = regexp.MustCompile(`(?im)^DROP\s+(COLLECTION|TABLE)` + ifExists + `\s+(` + field + `\.)?` + field + `$`) - reListColls = regexp.MustCompile(`(?im)^LIST\s+(COLLECTIONS?|TABLES?)(\s+FROM\s+` + field + `)?$`) - - reInsert = regexp.MustCompile(`(?im)^(INSERT|UPSERT)\s+INTO\s+(` + field + `\.)?` + field + `\s*\(([^)]*?)\)\s*VALUES\s*\(([^)]*?)\)$`) - reSelect = regexp.MustCompile(`(?im)^SELECT\s+(CROSS\s+PARTITION\s+)?.*?\s+FROM\s+` + field + `.*?` + with + `$`) - reUpdate = regexp.MustCompile(`(?im)^UPDATE\s+(` + field + `\.)?` + field + `\s+SET\s+(.*)\s+WHERE\s+id\s*=\s*(.*)$`) - reDelete = regexp.MustCompile(`(?im)^DELETE\s+FROM\s+(` + field + `\.)?` + field + `\s+WHERE\s+id\s*=\s*(.*)$`) + reCreateDb = regexp.MustCompile(`(?is)^CREATE\s+DATABASE` + ifNotExists + `\s+` + field + with + `$`) + reAlterDb = regexp.MustCompile(`(?is)^ALTER\s+DATABASE` + `\s+` + field + with + `$`) + reDropDb = regexp.MustCompile(`(?is)^DROP\s+DATABASE` + ifExists + `\s+` + field + `$`) + reListDbs = regexp.MustCompile(`(?is)^LIST\s+DATABASES?$`) + + reCreateColl = regexp.MustCompile(`(?is)^CREATE\s+(COLLECTION|TABLE)` + ifNotExists + `\s+(` + field + `\.)?` + field + with + `$`) + reAlterColl = regexp.MustCompile(`(?is)^ALTER\s+(COLLECTION|TABLE)` + `\s+(` + field + `\.)?` + field + with + `$`) + reDropColl = regexp.MustCompile(`(?is)^DROP\s+(COLLECTION|TABLE)` + ifExists + `\s+(` + field + `\.)?` + field + `$`) + reListColls = regexp.MustCompile(`(?is)^LIST\s+(COLLECTIONS?|TABLES?)(\s+FROM\s+` + field + `)?$`) + + reInsert = regexp.MustCompile(`(?is)^(INSERT|UPSERT)\s+INTO\s+(` + field + `\.)?` + field + `\s*\(([^)]*?)\)\s*VALUES\s*\(([^)]*?)\)$`) + reSelect = regexp.MustCompile(`(?is)^SELECT\s+(CROSS\s+PARTITION\s+)?.*?\s+FROM\s+` + field + `.*?` + with + `$`) + reUpdate = regexp.MustCompile(`(?is)^UPDATE\s+(` + field + `\.)?` + field + `\s+SET\s+(.*)\s+WHERE\s+id\s*=\s*(.*)$`) + reDelete = regexp.MustCompile(`(?is)^DELETE\s+FROM\s+(` + field + `\.)?` + field + `\s+WHERE\s+id\s*=\s*(.*)$`) ) func parseQuery(c *Conn, query string) (driver.Stmt, error) { diff --git a/stmt_collection_parsing_test.go b/stmt_collection_parsing_test.go index 5d711fd..d616226 100644 --- a/stmt_collection_parsing_test.go +++ b/stmt_collection_parsing_test.go @@ -264,3 +264,87 @@ func TestStmtDropCollection_parse_defaultDb(t *testing.T) { }) } } + +func TestStmtListCollections_parse(t *testing.T) { + testName := "TestStmtListCollections_parse" + testData := []struct { + name string + sql string + expected *StmtListCollections + mustError bool + }{ + {name: "error_invalid_query1", sql: "LIST COLLECTIONS", mustError: true}, + {name: "error_invalid_query2", sql: "LIST TABLES", mustError: true}, + {name: "error_invalid_query3", sql: "LIST COLLECTION", mustError: true}, + {name: "error_invalid_query4", sql: "LIST TABLE", mustError: true}, + {name: "error_invalid_query5", sql: "LIST COLLECTIONS FROM", mustError: true}, + {name: "error_invalid_query6", sql: "LIST TABLES FROM", mustError: true}, + {name: "error_invalid_query7", sql: "LIST COLLECTION FROM", mustError: true}, + {name: "error_invalid_query8", sql: "LIST TABLE FROM", mustError: true}, + + {name: "basic", sql: "LIST COLLECTIONS from db1", expected: &StmtListCollections{dbName: "db1"}}, + {name: "collections", sql: "list \n\tcollection \t FROM \t\rdb-2", expected: &StmtListCollections{dbName: "db-2"}}, + {name: "tables", sql: "LIST tables\n\tFROM\t\rdb_3", expected: &StmtListCollections{dbName: "db_3"}}, + {name: "table", sql: "list TABLE from db-4_0", expected: &StmtListCollections{dbName: "db-4_0"}}, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtListCollections) + if !ok { + t.Fatalf("%s failed: expected StmtListCollections but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtListCollections_parse_defaultDb(t *testing.T) { + testName := "TestStmtListCollections_parse_defaultDb" + testData := []struct { + name string + db string + sql string + expected *StmtListCollections + mustError bool + }{ + {name: "basic", db: "mydb", sql: "LIST COLLECTIONS", expected: &StmtListCollections{dbName: "mydb"}}, + {name: "db_in_query", db: "mydb", sql: "list\r\tcollection FROM\n db-2", expected: &StmtListCollections{dbName: "db-2"}}, + {name: "tables", db: "mydb", sql: "LIST tables", expected: &StmtListCollections{dbName: "mydb"}}, + {name: "table_db_in_query", db: "mydb", sql: "list TABLE from db-4_0", expected: &StmtListCollections{dbName: "db-4_0"}}, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtListCollections) + if !ok { + t.Fatalf("%s failed: expected StmtListCollections but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} diff --git a/stmt_document_parsing_test.go b/stmt_document_parsing_test.go index c7270a2..f004f77 100644 --- a/stmt_document_parsing_test.go +++ b/stmt_document_parsing_test.go @@ -1 +1,615 @@ package gocosmos + +import ( + "reflect" + "testing" +) + +func TestStmtInsert_parse(t *testing.T) { + testName := "TestStmtInsert_parse" + testData := []struct { + name string + sql string + expected *StmtInsert + mustError bool + }{ + {name: "error_no_collection", sql: `INSERT INTO db (a,b,c) VALUES (1,2,3)`, mustError: true}, + {name: "error_values", sql: `INSERT INTO db.table (a,b,c)`, mustError: true}, + {name: "error_columns", sql: `INSERT INTO db.table VALUES (1,2,3)`, mustError: true}, + {name: "error_invalid_string", sql: `INSERT INTO db.table (a) VALUES ('a string')`, mustError: true}, + {name: "error_invalid_string2", sql: `INSERT INTO db.table (a) VALUES ("a string")`, mustError: true}, + {name: "error_invalid_string3", sql: `INSERT INTO db.table (a) VALUES ("{key:value}")`, mustError: true}, + {name: "error_num_values_not_matched", sql: `INSERT INTO db.table (a,b) VALUES (1,2,3)`, mustError: true}, + {name: "error_invalid_number", sql: `INSERT INTO db.table (a,b) VALUES (0x1qa,2)`, mustError: true}, + {name: "error_invalid_string", sql: `INSERT INTO db.table (a,b) VALUES ("cannot \\"unquote",2)`, mustError: true}, + + { + name: "basic", + sql: `INSERT INTO +db1.table1 (a, b, c, d, e, +f) VALUES + (null, 1.0, +true, "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, + expected: &StmtInsert{dbName: "db1", collName: "table1", fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, + }, + { + name: "with_placeholders", + sql: `INSERT +INTO db-2.table_2 ( +a,b,c) VALUES ( +$1, :3, @2)`, + expected: &StmtInsert{dbName: "db-2", collName: "table_2", fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtInsert) + if !ok { + t.Fatalf("%s failed: expected StmtInsert but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + stmt.fieldsStr = "" + stmt.valuesStr = "" + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtInsert_parse_defaultDb(t *testing.T) { + testName := "TestStmtInsert_parse_defaultDb" + testData := []struct { + name string + db string + sql string + expected *StmtInsert + mustError bool + }{ + {name: "error_invalid_query", sql: `INSERT INTO .table (a,b) VALUES (1,2)`, mustError: true}, + {name: "error_invalid_query2", sql: `INSERT INTO db. (a,b) VALUES (1,2)`, mustError: true}, + + { + name: "basic", + db: "mydb", + sql: `INSERT INTO +table1 (a, b, c, d, e, +f) VALUES + (null, 1.0, +true, "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, + expected: &StmtInsert{dbName: "mydb", collName: "table1", fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, + }, + { + name: "with_placeholders_table_in_query", + db: "mydb", + sql: `INSERT +INTO db-2.table_2 ( +a,b,c) VALUES ( +$1, :3, @2)`, + expected: &StmtInsert{dbName: "db-2", collName: "table_2", fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtInsert) + if !ok { + t.Fatalf("%s failed: expected StmtInsert but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + stmt.fieldsStr = "" + stmt.valuesStr = "" + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtUpsert_parse(t *testing.T) { + testName := "TestStmtUpsert_parse" + testData := []struct { + name string + sql string + expected *StmtInsert + mustError bool + }{ + {name: "error_no_collection", sql: `UPSERT INTO db (a,b,c) VALUES (1,2,3)`, mustError: true}, + {name: "error_values", sql: `UPSERT INTO db.table (a,b,c)`, mustError: true}, + {name: "error_columns", sql: `UPSERT INTO db.table VALUES (1,2,3)`, mustError: true}, + {name: "error_invalid_string", sql: `UPSERT INTO db.table (a) VALUES ('a string')`, mustError: true}, + {name: "error_invalid_string2", sql: `UPSERT INTO db.table (a) VALUES ("a string")`, mustError: true}, + {name: "error_invalid_string3", sql: `UPSERT INTO db.table (a) VALUES ("{key:value}")`, mustError: true}, + {name: "error_num_values_not_matched", sql: `UPSERT INTO db.table (a,b) VALUES (1,2,3)`, mustError: true}, + {name: "error_invalid_number", sql: `UPSERT INTO db.table (a,b) VALUES (0x1qa,2)`, mustError: true}, + {name: "error_invalid_string", sql: `UPSERT INTO db.table (a,b) VALUES ("cannot \\"unquote",2)`, mustError: true}, + + { + name: "basic", + sql: `UPSERT INTO +db1.table1 (a, +b, c, d, e, +f) VALUES + (null, 1.0, true, + "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, + expected: &StmtInsert{dbName: "db1", collName: "table1", isUpsert: true, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, + }, + { + name: "with_placeholders", + sql: `UPSERT +INTO db-2.table_2 ( +a,b,c) VALUES ($1, + :3, @2)`, + expected: &StmtInsert{dbName: "db-2", collName: "table_2", isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtInsert) + if !ok { + t.Fatalf("%s failed: expected StmtInsert but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + stmt.fieldsStr = "" + stmt.valuesStr = "" + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtUpsert_parse_defaultDb(t *testing.T) { + testName := "TestStmtUpsert_parse_defaultDb" + testData := []struct { + name string + db string + sql string + expected *StmtInsert + mustError bool + }{ + {name: "error_invalid_query", sql: `UPSERT INTO .table (a,b) VALUES (1,2)`, mustError: true}, + {name: "error_invalid_query2", sql: `UPSERT INTO db. (a,b) VALUES (1,2)`, mustError: true}, + + { + name: "basic", + db: "mydb", + sql: `UPSERT INTO +table1 (a, +b, c, d, e, +f) VALUES + (null, 1.0, true, + "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, + expected: &StmtInsert{dbName: "mydb", collName: "table1", isUpsert: true, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, + }, + { + name: "with_placeholders_table_in_query", + db: "mydb", + sql: `UPSERT +INTO db-2.table_2 ( +a,b,c) VALUES ($1, + :3, @2)`, + expected: &StmtInsert{dbName: "db-2", collName: "table_2", isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtInsert) + if !ok { + t.Fatalf("%s failed: expected StmtInsert but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + stmt.fieldsStr = "" + stmt.valuesStr = "" + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtDelete_parse(t *testing.T) { + testName := "TestStmtDelete_parse" + testData := []struct { + name string + sql string + expected *StmtDelete + mustError bool + }{ + {name: "error_no_collection", sql: `DELETE FROM db WHERE id=1`, mustError: true}, + {name: "error_where", sql: `DELETE FROM db.table`, mustError: true}, + {name: "error_empty_id", sql: `DELETE FROM db.table WHERE id=`, mustError: true}, + {name: "error_invalid_value", sql: `DELETE FROM db.table WHERE id="1`, mustError: true}, + {name: "error_invalid_value2", sql: `DELETE FROM db.table WHERE id=2"`, mustError: true}, + {name: "error_invalid_where", sql: `DELETE FROM db.table WHERE id=@1 a`, mustError: true}, + {name: "error_invalid_where2", sql: `DELETE FROM db.table WHERE id=b $2`, mustError: true}, + {name: "error_invalid_where3", sql: `DELETE FROM db.table WHERE id=c :3 d`, mustError: true}, + + { + name: "basic", + sql: `DELETE FROM +db1.table1 WHERE + id=abc`, + expected: &StmtDelete{dbName: "db1", collName: "table1", idStr: "abc"}, + }, + { + name: "basic2", + sql: ` + DELETE +FROM db-2.table_2 + WHERE id="def"`, + expected: &StmtDelete{dbName: "db-2", collName: "table_2", idStr: "def"}, + }, + { + name: "basic3", + sql: `DELETE FROM +db_3-0.table-3_0 WHERE + id=@2`, + expected: &StmtDelete{dbName: "db_3-0", collName: "table-3_0", idStr: "@2", id: placeholder{2}}, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtDelete) + if !ok { + t.Fatalf("%s failed: expected StmtDelete but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtDelete_parse_defaultDb(t *testing.T) { + testName := "TestStmtDelete_parse_defaultDb" + testData := []struct { + name string + db string + sql string + expected *StmtDelete + mustError bool + }{ + {name: "error_invalid_query", sql: `DELETE FROM .table WHERE id=1`, mustError: true}, + {name: "error_invalid_query2", sql: `DELETE FROM db. WHERE id=1`, mustError: true}, + + { + name: "basic", + db: "mydb", + sql: `DELETE FROM +table1 WHERE + id=abc`, + expected: &StmtDelete{dbName: "mydb", collName: "table1", idStr: "abc"}, + }, + { + name: "db_in_query", + db: "mydb", + sql: ` + DELETE +FROM db-2.table_2 + WHERE id="def"`, + expected: &StmtDelete{dbName: "db-2", collName: "table_2", idStr: "def"}, + }, + { + name: "placeholder", + db: "mydb", + sql: `DELETE FROM +db_3-0.table-3_0 WHERE + id=@2`, + expected: &StmtDelete{dbName: "db_3-0", collName: "table-3_0", idStr: "@2", id: placeholder{2}}, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtDelete) + if !ok { + t.Fatalf("%s failed: expected StmtDelete but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtSelect_parse(t *testing.T) { + testName := "TestStmtSelect_parse" + testData := []struct { + name string + sql string + expected *StmtSelect + mustError bool + }{ + {name: "error_db_and_collection", sql: `SELECT * FROM db.table`, mustError: true}, + {name: "error_no_collection", sql: `SELECT * WITH db=dbname`, mustError: true}, + {name: "error_no_db", sql: `SELECT * FROM c WITH collection=collname`, mustError: true}, + {name: "error_cross_partition_must_be_true", sql: `SELECT * FROM c WITH db=dbname WITH collection=collname WITH cross_partition=false`, mustError: true}, + + { + name: "basic", + sql: `SELECT * FROM c WITH database=db WITH collection=tbl`, + expected: &StmtSelect{dbName: "db", collName: "tbl", selectQuery: `SELECT * FROM c`, placeholders: map[int]string{}}, + }, + { + name: "cross_partition", + sql: `SELECT CROSS PARTITION * FROM c WHERE id="1" WITH db=db-1 WITH table=tbl_1`, + expected: &StmtSelect{dbName: "db-1", collName: "tbl_1", isCrossPartition: true, selectQuery: `SELECT * FROM c WHERE id="1"`, placeholders: map[int]string{}}, + }, + { + name: "placeholders", + sql: `SELECT id,username,email FROM c WHERE username!=@1 AND (id>:2 OR email=$3) WITH CROSS_PARTITION=true WITH database=db_3-0 WITH table=table-3_0`, + expected: &StmtSelect{dbName: "db_3-0", collName: "table-3_0", isCrossPartition: true, selectQuery: `SELECT id,username,email FROM c WHERE username!=@_1 AND (id>@_2 OR email=@_3)`, placeholders: map[int]string{1: "@_1", 2: "@_2", 3: "@_3"}}, + }, + { + name: "collection_in_query", + sql: `SELECT a,b,c FROM user u WHERE u.id="1" WITH db=dbtemp`, + expected: &StmtSelect{dbName: "dbtemp", collName: "user", selectQuery: `SELECT a,b,c FROM user u WHERE u.id="1"`, placeholders: map[int]string{}}, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtSelect) + if !ok { + t.Fatalf("%s failed: expected StmtSelect but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtSelect_parse_defaultDb(t *testing.T) { + testName := "TestStmtSelect_parse_defaultDb" + testData := []struct { + name string + db string + sql string + expected *StmtSelect + mustError bool + }{ + { + name: "basic", + db: "mydb", + sql: `SELECT * FROM c WITH collection=tbl`, + expected: &StmtSelect{dbName: "mydb", collName: "tbl", selectQuery: `SELECT * FROM c`, placeholders: map[int]string{}}, + }, + { + name: "db_table_in_query", + db: "mydb", + sql: `SELECT CROSS PARTITION * FROM c WHERE id="1" WITH db=db-1 WITH table=tbl_1`, + expected: &StmtSelect{dbName: "db-1", collName: "tbl_1", isCrossPartition: true, selectQuery: `SELECT * FROM c WHERE id="1"`, placeholders: map[int]string{}}, + }, + { + name: "placeholders", + db: "mydb", + sql: `SELECT id,username,email FROM c WHERE username!=@1 AND (id>:2 OR email=$3) WITH CROSS_PARTITION=true WITH table=tbl_2-0`, + expected: &StmtSelect{dbName: "mydb", collName: "tbl_2-0", isCrossPartition: true, selectQuery: `SELECT id,username,email FROM c WHERE username!=@_1 AND (id>@_2 OR email=@_3)`, placeholders: map[int]string{1: "@_1", 2: "@_2", 3: "@_3"}}, + }, + { + name: "collection_in_query", + db: "mydb", + sql: `SELECT a,b,c FROM user u WHERE u.id="1"`, + expected: &StmtSelect{dbName: "mydb", collName: "user", selectQuery: `SELECT a,b,c FROM user u WHERE u.id="1"`, placeholders: map[int]string{}}, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtSelect) + if !ok { + t.Fatalf("%s failed: expected StmtSelect but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtUpdate_parse(t *testing.T) { + testName := "TestStmtUpdate_parse" + testData := []struct { + name string + sql string + expected *StmtUpdate + mustError bool + }{ + {name: "error_no_collection", sql: `UPDATE db SET a=1,b=2,c=3 WHERE id=4`, mustError: true}, + {name: "error_where", sql: `UPDATE db.table SET a=1,b=2,c=3 WHERE username=4`, mustError: true}, + {name: "error_no_where", sql: `UPDATE db.table SET a=1,b=2,c=3`, mustError: true}, + {name: "error_no_set", sql: `UPDATE db.table WHERE id=1`, mustError: true}, + {name: "error_empty_set", sql: `UPDATE db.table SET WHERE id=1`, mustError: true}, + {name: "error_invalid_value", sql: `UPDATE db.table SET a="{key:value}" WHERE id=1`, mustError: true}, + {name: "error_invalid_query", sql: `UPDATE db.table SET =1 WHERE id=2`, mustError: true}, + {name: "error_invalid_query2", sql: `UPDATE db.table SET a=1 WHERE id= `, mustError: true}, + {name: "error_invalid_query3", sql: `UPDATE db.table SET a=1,b=2,c=3 WHERE id="4`, mustError: true}, + + { + name: "basic", + sql: `UPDATE db1.table1 +SET a=null, b= + 1.0, c=true, + d="\"a string 'with' \\\"quote\\\"\"", e="{\"key\":\"value\"}" +,f="[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]" WHERE + id="abc"`, + expected: &StmtUpdate{dbName: "db1", collName: "table1", updateStr: `a=null, b= + 1.0, c=true, + d="\"a string 'with' \\\"quote\\\"\"", e="{\"key\":\"value\"}" +,f="[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]"`, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}, idStr: "abc"}, + }, + { + name: "basic2", + sql: `UPDATE db-1.table_1 +SET a=$1, b= + $2, c=:3, d=0 WHERE + id=@4`, + expected: &StmtUpdate{dbName: "db-1", collName: "table_1", updateStr: `a=$1, b= + $2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtUpdate) + if !ok { + t.Fatalf("%s failed: expected StmtUpdate but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtUpdate_parse_defaultDb(t *testing.T) { + testName := "TestStmtUpdate_parse_defaultDb" + testData := []struct { + name string + db string + sql string + expected *StmtUpdate + mustError bool + }{ + {name: "error_invalid_query", sql: `UPDATE .table SET a=1,b=2,c=3 WHERE id=4`, mustError: true}, + {name: "error_invalid_query2", sql: `UPDATE db. SET a=1,b=2,c=3 WHERE id=4`, mustError: true}, + + { + name: "basic", + db: "mydb", + sql: `UPDATE table1 +SET a=null, b= + 1.0, c=true, + d="\"a string 'with' \\\"quote\\\"\"", e="{\"key\":\"value\"}" +,f="[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]" WHERE + id="abc"`, + expected: &StmtUpdate{dbName: "mydb", collName: "table1", updateStr: `a=null, b= + 1.0, c=true, + d="\"a string 'with' \\\"quote\\\"\"", e="{\"key\":\"value\"}" +,f="[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]"`, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}, idStr: "abc"}}, + { + name: "db_in_query", + db: "mydb", + sql: `UPDATE db-1.table_1 +SET a=$1, b= + $2, c=:3, d=0 WHERE + id=@4`, + expected: &StmtUpdate{dbName: "db-1", collName: "table_1", updateStr: `a=$1, b= + $2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtUpdate) + if !ok { + t.Fatalf("%s failed: expected StmtUpdate but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} diff --git a/stmt_test.go b/stmt_test.go index 274c87c..3c14fad 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -1,7 +1,6 @@ package gocosmos import ( - "reflect" "testing" ) @@ -35,606 +34,3 @@ func TestStmt_NumInput(t *testing.T) { } } } - -/*----------------------------------------------------------------------*/ - -func Test_parseQuery_ListCollections(t *testing.T) { - name := "Test_parseQuery_ListCollections" - testData := map[string]string{ - "LIST COLLECTIONS from db1": "db1", - "list\n\tcollection FROM\r\n db-2": "db-2", - "LIST tables\r\n\tFROM\tdb_3": "db_3", - "list TABLE from db-4_0": "db-4_0", - } - - for query, data := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtListCollections); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtListDatabases", name+"/"+query) - } else if dbstmt.dbName != data { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data, dbstmt.dbName) - } - } - - invalidQueries := []string{ - "LIST COLLECTIONS", - "LIST TABLES", - "LIST COLLECTION", - "LIST TABLE", - "LIST COLLECTIONS FROM", - "LIST TABLES FROM", - "LIST COLLECTION FROM", - "LIST TABLE FROM", - } - for _, query := range invalidQueries { - if _, err := parseQuery(nil, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_ListCollectionsDefaultDb(t *testing.T) { - name := "Test_parseQuery_ListCollectionsDefaultDb" - dbName := "mydb" - testData := map[string]string{ - "LIST COLLECTIONS": dbName, - "list\n\tcollection FROM\r\n db-2": "db-2", - "LIST tables": dbName, - "list TABLE from db-4_0": "db-4_0", - } - - for query, data := range testData { - if stmt, err := parseQueryWithDefaultDb(nil, dbName, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtListCollections); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtListDatabases", name+"/"+query) - } else if dbstmt.dbName != data { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data, dbstmt.dbName) - } - } -} - -func Test_parseQuery_Insert(t *testing.T) { - name := "Test_parseQuery_Insert" - type testStruct struct { - dbName string - collName string - fields []string - values []interface{} - } - testData := map[string]testStruct{ - `INSERT INTO -db1.table1 (a, b, c, d, e, -f) VALUES - (null, 1.0, -true, "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`: { - dbName: "db1", collName: "table1", fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{ - nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}, - }, - }, - `INSERT -INTO db-2.table_2 ( -a,b,c) VALUES ( -$1, :3, @2)`: { - dbName: "db-2", collName: "table_2", fields: []string{"a", "b", "c"}, values: []interface{}{ - placeholder{1}, placeholder{3}, placeholder{2}, - }, - }, - } - for query, data := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtInsert); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtInsert", name+"/"+query) - } else if dbstmt.isUpsert { - t.Fatalf("%s failed: is-upsert must be disabled", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if !reflect.DeepEqual(dbstmt.fields, data.fields) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.fields, dbstmt.fields) - } else if !reflect.DeepEqual(dbstmt.values, data.values) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.values, dbstmt.values) - } - } - - invalidQueries := []string{ - `INSERT INTO db (a,b,c) VALUES (1,2,3)`, // no collection name - `INSERT INTO db.table (a,b,c)`, // no VALUES part - `INSERT INTO db.table VALUES (1,2,3)`, // no column list - `INSERT INTO db.table (a) VALUES ('a string')`, // invalid string literature - `INSERT INTO db.table (a) VALUES ("a string")`, // should be "\"a string\"" - `INSERT INTO db.table (a) VALUES ("{key:value}")`, // should be "{\"key\:\"value\"}" - `INSERT INTO db.table (a,b) VALUES (1,2,3)`, // number of field and value mismatch - `INSERT INTO db.table (a,b) VALUES (0x1qa,2)`, // invalid number - `INSERT INTO db.table (a,b) VALUES ("cannot \\"unquote",2)`, // invalid string - } - for _, query := range invalidQueries { - if _, err := parseQuery(nil, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_InsertDefaultDb(t *testing.T) { - name := "Test_parseQuery_InsertDefaultDb" - dbName := "mydb" - type testStruct struct { - dbName string - collName string - fields []string - values []interface{} - } - testData := map[string]testStruct{ - `INSERT INTO -table1 (a, b, c, d, e, -f) VALUES - (null, 1.0, -true, "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`: { - dbName: dbName, collName: "table1", fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{ - nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}, - }, - }, - `INSERT -INTO db-2.table_2 ( -a,b,c) VALUES ( -$1, :3, @2)`: { - dbName: "db-2", collName: "table_2", fields: []string{"a", "b", "c"}, values: []interface{}{ - placeholder{1}, placeholder{3}, placeholder{2}, - }, - }, - } - for query, data := range testData { - if stmt, err := parseQueryWithDefaultDb(nil, dbName, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtInsert); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtInsert", name+"/"+query) - } else if dbstmt.isUpsert { - t.Fatalf("%s failed: is-upsert must be disabled", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if !reflect.DeepEqual(dbstmt.fields, data.fields) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.fields, dbstmt.fields) - } else if !reflect.DeepEqual(dbstmt.values, data.values) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.values, dbstmt.values) - } - } - - invalidQueries := []string{ - `INSERT INTO .table (a,b) VALUES (1,2)`, - } - for _, query := range invalidQueries { - if _, err := parseQueryWithDefaultDb(nil, dbName, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_Upsert(t *testing.T) { - name := "Test_parseQuery_Upsert" - type testStruct struct { - dbName string - collName string - fields []string - values []interface{} - } - testData := map[string]testStruct{ - `UPSERT INTO -db1.table1 (a, -b, c, d, e, -f) VALUES - (null, 1.0, true, - "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`: { - dbName: "db1", collName: "table1", fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{ - nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}, - }, - }, - `UPSERT -INTO db-2.table_2 ( -a,b,c) VALUES ($1, - :3, @2)`: { - dbName: "db-2", collName: "table_2", fields: []string{"a", "b", "c"}, values: []interface{}{ - placeholder{1}, placeholder{3}, placeholder{2}, - }, - }, - } - for query, data := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtInsert); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtInsert", name+"/"+query) - } else if !dbstmt.isUpsert { - t.Fatalf("%s failed: is-upsert must be enabled", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if !reflect.DeepEqual(dbstmt.fields, data.fields) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.fields, dbstmt.fields) - } else if !reflect.DeepEqual(dbstmt.values, data.values) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.values, dbstmt.values) - } - } - - invalidQueries := []string{ - `UPSERT INTO db (a,b,c) VALUES (1,2,3)`, // no collection name - `UPSERT INTO db.table (a,b,c)`, // no VALUES part - `UPSERT INTO db.table VALUES (1,2,3)`, // no column list - `UPSERT INTO db.table (a) VALUES ('a string')`, // invalid string literature - `UPSERT INTO db.table (a) VALUES ("a string")`, // should be "\"a string\"" - `UPSERT INTO db.table (a) VALUES ("{key:value}")`, // should be "{\"key\:\"value\"}" - `UPSERT INTO db.table (a,b) VALUES (1,2,3)`, // number of field and value mismatch - } - for _, query := range invalidQueries { - if _, err := parseQuery(nil, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_UpsertDefaultDb(t *testing.T) { - name := "Test_parseQuery_UpsertDefaultDb" - dbName := "mydb" - type testStruct struct { - dbName string - collName string - fields []string - values []interface{} - } - testData := map[string]testStruct{ - `UPSERT INTO -table1 (a, -b, c, d, e, -f) VALUES - (null, 1.0, true, - "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`: { - dbName: dbName, collName: "table1", fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{ - nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}, - }, - }, - `UPSERT -INTO db-2.table_2 ( -a,b,c) VALUES ($1, - :3, @2)`: { - dbName: "db-2", collName: "table_2", fields: []string{"a", "b", "c"}, values: []interface{}{ - placeholder{1}, placeholder{3}, placeholder{2}, - }, - }, - } - for query, data := range testData { - if stmt, err := parseQueryWithDefaultDb(nil, dbName, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtInsert); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtInsert", name+"/"+query) - } else if !dbstmt.isUpsert { - t.Fatalf("%s failed: is-upsert must be enabled", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if !reflect.DeepEqual(dbstmt.fields, data.fields) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.fields, dbstmt.fields) - } else if !reflect.DeepEqual(dbstmt.values, data.values) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.values, dbstmt.values) - } - } - - invalidQueries := []string{ - `UPSERT INTO .table (a,b,c) VALUES (1,2,3)`, - } - for _, query := range invalidQueries { - if _, err := parseQueryWithDefaultDb(nil, dbName, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_Delete(t *testing.T) { - name := "Test_parseQuery_Delete" - type testStruct struct { - dbName string - collName string - idStr string - id interface{} - } - testData := map[string]testStruct{ - `DELETE FROM -db1.table1 WHERE - id=abc`: {dbName: "db1", collName: "table1", idStr: "abc", id: nil}, - ` - DELETE -FROM db-2.table_2 - WHERE id="def"`: {dbName: "db-2", collName: "table_2", idStr: "def", id: nil}, - `DELETE FROM -db_3-0.table-3_0 WHERE - id=@2`: {dbName: "db_3-0", collName: "table-3_0", idStr: "@2", id: placeholder{2}}, - } - for query, data := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtDelete); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtDelete", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if dbstmt.idStr != data.idStr { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.idStr, dbstmt.idStr) - } else if !reflect.DeepEqual(dbstmt.id, data.id) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.id, dbstmt.id) - } - } - - invalidQueries := []string{ - `DELETE FROM db WHERE id=1`, // no collection name - `DELETE FROM db.table`, // no WHERE part - `DELETE FROM db.table WHERE id=`, // id is empty - `DELETE FROM db.table WHERE id="1`, - `DELETE FROM db.table WHERE id=2"`, - `DELETE FROM db.table WHERE id=@1 a`, - `DELETE FROM db.table WHERE id=b $2`, - `DELETE FROM db.table WHERE id=c :3 d`, - } - for _, query := range invalidQueries { - if _, err := parseQuery(nil, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_DeleteDefaultDb(t *testing.T) { - name := "Test_parseQuery_DeleteDefaultDb" - dbName := "mydb" - type testStruct struct { - dbName string - collName string - idStr string - id interface{} - } - testData := map[string]testStruct{ - `DELETE FROM -table1 WHERE - id=abc`: {dbName: dbName, collName: "table1", idStr: "abc", id: nil}, - ` - DELETE -FROM db-2.table_2 - WHERE id="def"`: {dbName: "db-2", collName: "table_2", idStr: "def", id: nil}, - `DELETE FROM -db_3-0.table-3_0 WHERE - id=@2`: {dbName: "db_3-0", collName: "table-3_0", idStr: "@2", id: placeholder{2}}, - } - for query, data := range testData { - if stmt, err := parseQueryWithDefaultDb(nil, dbName, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtDelete); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtDelete", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if dbstmt.idStr != data.idStr { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.idStr, dbstmt.idStr) - } else if !reflect.DeepEqual(dbstmt.id, data.id) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.id, dbstmt.id) - } - } - - invalidQueries := []string{ - `DELETE FROM .table WHERE id=1`, // no collection name - } - for _, query := range invalidQueries { - if _, err := parseQueryWithDefaultDb(nil, dbName, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_Select(t *testing.T) { - name := "Test_parseQuery_Select" - type testStruct struct { - dbName string - collName string - isCrossPartition bool - selectQuery string - } - testData := map[string]testStruct{ - `SELECT * FROM c WITH database=db WITH collection=tbl`: { - dbName: "db", collName: "tbl", isCrossPartition: false, selectQuery: `SELECT * FROM c`}, - `SELECT CROSS PARTITION * FROM c WHERE id="1" WITH db=db-1 WITH table=tbl_1`: { - dbName: "db-1", collName: "tbl_1", isCrossPartition: true, selectQuery: `SELECT * FROM c WHERE id="1"`}, - `SELECT id,username,email FROM c WHERE username!=@1 AND (id>:2 OR email=$3) WITH CROSS_PARTITION=true WITH database=db WITH table=tbl`: { - dbName: "db", collName: "tbl", isCrossPartition: true, selectQuery: `SELECT id,username,email FROM c WHERE username!=@_1 AND (id>@_2 OR email=@_3)`}, - `SELECT a,b,c FROM user u WHERE u.id="1" WITH db=dbtemp`: { - dbName: "dbtemp", collName: "user", isCrossPartition: false, selectQuery: `SELECT a,b,c FROM user u WHERE u.id="1"`}, - } - for query, data := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtSelect); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtSelect", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if dbstmt.isCrossPartition != data.isCrossPartition { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.isCrossPartition, dbstmt.isCrossPartition) - } else if dbstmt.selectQuery != data.selectQuery { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.selectQuery, dbstmt.selectQuery) - } - } - - invalidQueries := []string{ - `SELECT * FROM db.table`, // database and collection must be specified by WITH database= and WITH collection= - `SELECT * WITH db=dbname`, // no collection - `SELECT * FROM c WITH collection=collname`, // no database - `SELECT * FROM c WITH db=dbname WITH collection=collname WITH cross_partition=false`, // the only valid value for cross_partition is true - } - for _, query := range invalidQueries { - if _, err := parseQuery(nil, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_SelectDefaultDb(t *testing.T) { - name := "Test_parseQuery_SelectDefaultDb" - dbName := "mydb" - type testStruct struct { - dbName string - collName string - isCrossPartition bool - selectQuery string - } - testData := map[string]testStruct{ - `SELECT * FROM c WITH collection=tbl`: { - dbName: dbName, collName: "tbl", isCrossPartition: false, selectQuery: `SELECT * FROM c`}, - `SELECT CROSS PARTITION * FROM c WHERE id="1" WITH db=db-1 WITH table=tbl_1`: { - dbName: "db-1", collName: "tbl_1", isCrossPartition: true, selectQuery: `SELECT * FROM c WHERE id="1"`}, - `SELECT id,username,email FROM c WHERE username!=@1 AND (id>:2 OR email=$3) WITH CROSS_PARTITION=true WITH table=tbl`: { - dbName: dbName, collName: "tbl", isCrossPartition: true, selectQuery: `SELECT id,username,email FROM c WHERE username!=@_1 AND (id>@_2 OR email=@_3)`}, - `SELECT a,b,c FROM user u WHERE u.id="1"`: { - dbName: dbName, collName: "user", isCrossPartition: false, selectQuery: `SELECT a,b,c FROM user u WHERE u.id="1"`}, - } - for query, data := range testData { - if stmt, err := parseQueryWithDefaultDb(nil, dbName, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtSelect); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtSelect", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if dbstmt.isCrossPartition != data.isCrossPartition { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.isCrossPartition, dbstmt.isCrossPartition) - } else if dbstmt.selectQuery != data.selectQuery { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.selectQuery, dbstmt.selectQuery) - } - } -} - -func Test_parseQuery_Update(t *testing.T) { - name := "Test_parseQuery_Update" - type testStruct struct { - dbName string - collName string - idStr string - id interface{} - fields []string - values []interface{} - } - testData := map[string]testStruct{ - `UPDATE db1.table1 -SET a=null, b= - 1.0, c=true, - d="\"a string 'with' \\\"quote\\\"\"", e="{\"key\":\"value\"}" -,f="[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]" WHERE - id="abc"`: { - dbName: "db1", collName: "table1", fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{ - nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}, - }, idStr: "abc", id: nil}, - `UPDATE db-1.table_1 -SET a=$1, b= - $2, c=:3, d=0 WHERE - id=@4`: { - dbName: "db-1", collName: "table_1", fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, - idStr: "@4", id: placeholder{4}}, - } - for query, data := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtUpdate); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtUpdate", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if dbstmt.idStr != data.idStr { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.idStr, dbstmt.idStr) - } else if dbstmt.id != data.id { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.id, dbstmt.id) - } else if !reflect.DeepEqual(dbstmt.fields, data.fields) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.fields, dbstmt.fields) - } else if !reflect.DeepEqual(dbstmt.values, data.values) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.values, dbstmt.values) - } - } - - invalidQueries := []string{ - `UPDATE db SET a=1,b=2,c=3 WHERE id=4`, // no collection name - `UPDATE db.table SET a=1,b=2,c=3 WHERE username=4`, // only WHERE id... is accepted - `UPDATE db.table SET a=1,b=2,c=3`, // no WHERE clause - `UPDATE db.table WHERE id=1`, // no SET clause - `UPDATE db.table SET WHERE id=1`, // SET clause is empty - `UPDATE db.table SET a="{key:value}" WHERE id=1`, // should be "{\"key\:\"value\"}" - `UPDATE db.table SET =1 WHERE id=2`, // invalid SET clause - `UPDATE db.table SET a=1 WHERE id= `, // empty id - `UPDATE db.table SET a=1,b=2,c=3 WHERE id="4`, // invalid id literate - } - for _, query := range invalidQueries { - if _, err := parseQuery(nil, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} - -func Test_parseQuery_UpdateDefaultDb(t *testing.T) { - name := "Test_parseQuery_UpdateDefaultDb" - dbName := "mydb" - type testStruct struct { - dbName string - collName string - idStr string - id interface{} - fields []string - values []interface{} - } - testData := map[string]testStruct{ - `UPDATE table1 -SET a=null, b= - 1.0, c=true, - d="\"a string 'with' \\\"quote\\\"\"", e="{\"key\":\"value\"}" -,f="[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]" WHERE - id="abc"`: { - dbName: dbName, collName: "table1", fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{ - nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}, - }, idStr: "abc", id: nil}, - `UPDATE db-1.table_1 -SET a=$1, b= - $2, c=:3, d=0 WHERE - id=@4`: { - dbName: "db-1", collName: "table_1", fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, - idStr: "@4", id: placeholder{4}}, - } - for query, data := range testData { - if stmt, err := parseQueryWithDefaultDb(nil, dbName, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if dbstmt, ok := stmt.(*StmtUpdate); !ok { - t.Fatalf("%s failed: the parsed stmt must be of type *StmtUpdate", name+"/"+query) - } else if dbstmt.dbName != data.dbName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.dbName, dbstmt.dbName) - } else if dbstmt.collName != data.collName { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.collName, dbstmt.collName) - } else if dbstmt.idStr != data.idStr { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.idStr, dbstmt.idStr) - } else if dbstmt.id != data.id { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.id, dbstmt.id) - } else if !reflect.DeepEqual(dbstmt.fields, data.fields) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.fields, dbstmt.fields) - } else if !reflect.DeepEqual(dbstmt.values, data.values) { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, data.values, dbstmt.values) - } - } - - invalidQueries := []string{ - `UPDATE .table SET a=1,b=2,c=3 WHERE id=4`, - } - for _, query := range invalidQueries { - if _, err := parseQueryWithDefaultDb(nil, dbName, query); err == nil { - t.Fatalf("%s failed: query must not be parsed/validated successfully", name+"/"+query) - } - } -} From 7699ba40b3b584017352f17ca2010313329ef75e Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Mon, 5 Jun 2023 01:07:09 +1000 Subject: [PATCH 03/13] refactor StmtDatabase and tests --- .github/workflows/gocosmos.yaml | 41 ++++-- driver.go | 15 ++ gocosmos.go | 25 +++- gocosmos_test.go | 157 -------------------- stmt.go | 148 ++++++++++++++++++- stmt_database.go | 94 ++---------- stmt_database_test.go | 253 ++++++++++++++++++++++++++++++++ 7 files changed, 476 insertions(+), 257 deletions(-) create mode 100644 stmt_database_test.go diff --git a/.github/workflows/gocosmos.yaml b/.github/workflows/gocosmos.yaml index 38897bc..63bf91c 100644 --- a/.github/workflows/gocosmos.yaml +++ b/.github/workflows/gocosmos.yaml @@ -7,8 +7,27 @@ on: branches: [ main ] jobs: - testDriver: - name: Test driver + testDriverQueryParsing: + name: Test driver query parsing + runs-on: ubuntu-latest + steps: + - name: Set up Go env + uses: actions/setup-go@v2 + with: + go-version: ^1.13 + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + - name: Test + run: | + go test -cover -coverprofile="coverage_driver_parse.txt" -v -count 1 -p 1 -run "_parse" . + - name: Codecov + uses: codecov/codecov-action@v3 + with: + flags: driver_parse + name: driver_parse + + testDriverStmtDatabase: + name: Test driver database statements runs-on: windows-latest steps: - name: Set up Go env @@ -26,12 +45,12 @@ jobs: netstat -nt $env:COSMOSDB_DRIVER='gocosmos' $env:COSMOSDB_URL='AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==' - go test -cover -coverprofile="coverage_driver.txt" -v -count 1 -p 1 -run "Test_" . + go test -cover -coverprofile="coverage_driver_database.txt" -v -count 1 -p 1 -run "TestStmt.*Database_(Exec|Query)" . - name: Codecov uses: codecov/codecov-action@v3 with: - flags: driver - name: driver + flags: driver_database + name: driver_database testDriverSelect: name: Test driver SELECT query @@ -113,7 +132,7 @@ jobs: testOther: name: Test other - runs-on: windows-latest + runs-on: ubuntu-latest steps: - name: Set up Go env uses: actions/setup-go@v2 @@ -123,14 +142,8 @@ jobs: uses: actions/checkout@v2 - name: Test run: | - choco install azure-cosmosdb-emulator - & "C:\Program Files\Azure Cosmos DB Emulator\Microsoft.Azure.Cosmos.Emulator.exe" /DisableRateLimiting /NoUI /NoExplorer - Start-Sleep -s 60 - try { Invoke-RestMethod -Method GET https://127.0.0.1:8081/ } catch {} - netstat -nt - $env:COSMOSDB_DRIVER='gocosmos' - $env:COSMOSDB_URL='AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==' - go test -cover -coverprofile="coverage_other.txt" -v -count 1 -p 1 -run "TestNew" . + export COSMOSDB_URL="AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" + go test -cover -coverprofile="coverage_other.txt" -v -count 1 -p 1 -run "TestNew|TestStmt_NumInput|TestDriver_" . - name: Codecov uses: codecov/codecov-action@v3 with: diff --git a/driver.go b/driver.go index a38b06e..942137d 100644 --- a/driver.go +++ b/driver.go @@ -65,6 +65,21 @@ var ( // ErrConflict is returned when the executing operation cause conflict (e.g. duplicated id). ErrConflict = errors.New("StatusCode=409 Conflict") + + // ErrOperationNotSupported is returned to indicate that the operation is not supported. + // + // @Available since v0.3.0 + ErrOperationNotSupported = errors.New("this operation is not supported") + + // ErrExecNotSupported is returned to indicate that the Exec/ExecContext operation is not supported. + // + // @Available since v0.3.0 + ErrExecNotSupported = errors.New("this operation is not supported, please use Query") + + // ErrQueryNotSupported is returned to indicate that the Query/QueryContext operation is not supported. + // + // @Available since v0.3.0 + ErrQueryNotSupported = errors.New("this operation is not supported, please use Exec") ) // Driver is Azure Cosmos DB implementation of driver.Driver. diff --git a/gocosmos.go b/gocosmos.go index d7658ec..0fbc91f 100644 --- a/gocosmos.go +++ b/gocosmos.go @@ -1,7 +1,30 @@ // Package gocosmos provides database/sql driver and a REST API client for Azure Cosmos DB SQL API. package gocosmos +import ( + "reflect" +) + const ( // Version of package gocosmos. - Version = "0.2.0" + Version = "0.3.0" ) + +func goTypeToCosmosDbType(typ reflect.Type) string { + if typ == nil { + return "" + } + switch typ.Kind() { + case reflect.Bool: + return "BOOLEAN" + case reflect.String: + return "STRING" + case reflect.Float32, reflect.Float64: + return "NUMBER" + case reflect.Array, reflect.Slice: + return "ARRAY" + case reflect.Map: + return "MAP" + } + return "" +} diff --git a/gocosmos_test.go b/gocosmos_test.go index f296bca..d415802 100644 --- a/gocosmos_test.go +++ b/gocosmos_test.go @@ -129,163 +129,6 @@ func TestDriver_Close(t *testing.T) { /*----------------------------------------------------------------------*/ -func Test_Query_CreateDatabase(t *testing.T) { - name := "Test_Query_CreateDatabase" - db := _openDb(t, name) - _, err := db.Query("CREATE DATABASE dbtemp") - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -func Test_Exec_CreateDatabase(t *testing.T) { - name := "Test_Query_CreateDatabase" - db := _openDb(t, name) - dbname := "dbtemp" - defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) - - // clean up - db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) - - // first creation should be successful - result, err := db.Exec(fmt.Sprintf("CREATE DATABASE %s WITH ru=400", dbname)) - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if regexp.MustCompile(`(?i){\s*LastInsertId\s*:\s*[^}]+?\s*}`).FindString(err.Error()) == "" { - t.Fatalf("%s failed: can not catch LastInsertId / %s", name, err) - } - if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - // second creation should return ErrConflict - _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s WITH ru=400", dbname)) - if err != ErrConflict { - t.Fatalf("%s failed: expected ErrConflict but received %#v", name, err) - } - - // clean up - db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) - - // first creation should be successful - result, err = db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s WITH maxru=4000", dbname)) - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 && err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } - if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - // second creation should be successful with "IF NOT EXISTS" - result, err = db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s WITH maxru=4000", dbname)) - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 && err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } - if numRows, err := result.RowsAffected(); numRows != 0 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=0/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } -} - -func Test_Query_AlterDatabase(t *testing.T) { - name := "Test_Query_AlterDatabase" - db := _openDb(t, name) - _, err := db.Query("ALTER DATABASE dbtemp WITH ru=400") - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -func Test_Exec_AlterDatabase(t *testing.T) { - name := "Test_Exec_AlterDatabase" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_found") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp WITH ru=400") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - - result, err := db.Exec("ALTER DATABASE dbtemp WITH ru=500") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - result, err = db.Exec("ALTER DATABASE dbtemp WITH maxru=6000") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 && err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } - if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - _, err = db.Exec("ALTER DATABASE db_not_found WITH maxru=6000") - if err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %s", name, err) - } -} - -func Test_Exec_AlterDatabaseNoOffer(t *testing.T) { - name := "Test_Exec_AlterDatabaseNoOffer" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - - _, err := db.Exec("ALTER DATABASE dbtemp WITH maxru=6000") - if err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %s", name, err) - } -} - -func Test_Query_DropDatabase(t *testing.T) { - name := "Test_Query_DropDatabase" - db := _openDb(t, name) - _, err := db.Query("DROP DATABASE dbtemp") - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -func Test_Exec_DropDatabase(t *testing.T) { - name := "Test_Exec_DropDatabase" - db := _openDb(t, name) - - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - - // first drop should be successful - _, err := db.Exec("DROP DATABASE dbtemp") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - // second drop should return ErrNotFound - _, err = db.Exec("DROP DATABASE dbtemp") - if err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } - - // third drop should be successful with "IF EXISTS" - _, err = db.Exec("DROP DATABASE IF EXISTS dbtemp") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } -} - func Test_Exec_ListDatabases(t *testing.T) { name := "Test_Exec_ListDatabases" db := _openDb(t, name) diff --git a/stmt.go b/stmt.go index ffe3d79..b807b12 100644 --- a/stmt.go +++ b/stmt.go @@ -3,7 +3,10 @@ package gocosmos import ( "database/sql/driver" "fmt" + "io" + "reflect" "regexp" + "sort" "strings" ) @@ -227,17 +230,148 @@ func (s *Stmt) parseWithOpts(withOptsStr string) error { return nil } -// // validateWithOpts is no-op in this struct. Sub-implementations may override this behavior. -// func (s *Stmt) validateWithOpts() error { -// return nil -// } - -// Close implements driver.Stmt.Close. +// Close implements driver.Stmt/Close. func (s *Stmt) Close() error { return nil } -// NumInput implements driver.Stmt.NumInput. +// NumInput implements driver.Stmt/NumInput. func (s *Stmt) NumInput() int { return s.numInput } + +/*----------------------------------------------------------------------*/ + +func buildResultNoResultSet(restResponse *RestReponse, supportLastInsertId bool, rid string, ignoreErrorCode int) *ResultNoResultSet { + result := &ResultNoResultSet{ + err: restResponse.Error(), + lastInsertId: rid, + supportLastInsertId: supportLastInsertId, + } + if result.err == nil { + result.affectedRows = 1 + } + switch restResponse.StatusCode { + case 403: + if ignoreErrorCode == 403 { + result.err = nil + } else { + result.err = ErrForbidden + } + case 404: + if ignoreErrorCode == 404 { + result.err = nil + } else { + result.err = ErrNotFound + } + case 409: + if ignoreErrorCode == 409 { + result.err = nil + } else { + result.err = ErrConflict + } + } + return result +} + +// ResultNoResultSet captures the result from statements that do not expect a ResultSet to be returned. +// +// @Available since v0.3.0 +type ResultNoResultSet struct { + err error + affectedRows int64 + supportLastInsertId bool + lastInsertId string // holds the "_rid" if the operation returns it +} + +// LastInsertId implements driver.Result/LastInsertId. +func (r *ResultNoResultSet) LastInsertId() (int64, error) { + if r.err != nil { + return 0, r.err + } + if !r.supportLastInsertId { + return 0, ErrOperationNotSupported + } + return 0, fmt.Errorf(`{"last_insert_id":"%s"}`, r.lastInsertId) +} + +// RowsAffected implements driver.Result/RowsAffected. +func (r *ResultNoResultSet) RowsAffected() (int64, error) { + return r.affectedRows, r.err +} + +/*----------------------------------------------------------------------*/ + +// ResultResultSet captures the result from statements that expect a ResultSet to be returned. +// +// @Available since v0.3.0 +type ResultResultSet struct { + err error + count int + cursorCount int + columnList []string + columnTypes map[string]reflect.Type + rowData []map[string]interface{} +} + +func (r *ResultResultSet) init() *ResultResultSet { + if r.rowData == nil { + return r + } + if r.columnTypes == nil { + r.columnTypes = make(map[string]reflect.Type) + } + r.count = len(r.rowData) + colMap := make(map[string]bool) + for _, item := range r.rowData { + for col, val := range item { + colMap[col] = true + if r.columnTypes[col] == nil { + r.columnTypes[col] = reflect.TypeOf(val) + } + } + } + r.columnList = make([]string, 0, len(colMap)) + for col := range colMap { + r.columnList = append(r.columnList, col) + } + sort.Strings(r.columnList) + + return r +} + +// Columns implements driver.Rows/Columns. +func (r *ResultResultSet) Columns() []string { + return r.columnList +} + +// ColumnTypeScanType implements driver.RowsColumnTypeScanType/ColumnTypeScanType +func (r *ResultResultSet) ColumnTypeScanType(index int) reflect.Type { + return r.columnTypes[r.columnList[index]] +} + +// ColumnTypeDatabaseTypeName implements driver.RowsColumnTypeDatabaseTypeName/ColumnTypeDatabaseTypeName +func (r *ResultResultSet) ColumnTypeDatabaseTypeName(index int) string { + return goTypeToCosmosDbType(r.columnTypes[r.columnList[index]]) +} + +// Close implements driver.Rows/Close. +func (r *ResultResultSet) Close() error { + return r.err +} + +// Next implements driver.Rows/Next. +func (r *ResultResultSet) Next(dest []driver.Value) error { + if r.err != nil { + return r.err + } + if r.cursorCount >= r.count { + return io.EOF + } + rowData := r.rowData[r.cursorCount] + r.cursorCount++ + for i, colName := range r.columnList { + dest[i] = rowData[colName] + } + return nil +} diff --git a/stmt_database.go b/stmt_database.go index 6bdd231..bc92f89 100644 --- a/stmt_database.go +++ b/stmt_database.go @@ -56,47 +56,19 @@ func (s *StmtCreateDatabase) validate() error { // Query implements driver.Stmt.Query. // This function is not implemented, use Exec instead. func (s *StmtCreateDatabase) Query(_ []driver.Value) (driver.Rows, error) { - return nil, errors.New("this operation is not supported, please use exec") + return nil, ErrQueryNotSupported } // Exec implements driver.Stmt.Exec. // Upon successful call, this function return (*ResultCreateDatabase, nil). func (s *StmtCreateDatabase) Exec(_ []driver.Value) (driver.Result, error) { restResult := s.conn.restClient.CreateDatabase(DatabaseSpec{Id: s.dbName, Ru: s.ru, MaxRu: s.maxru}) - result := &ResultCreateDatabase{Successful: restResult.Error() == nil, InsertId: restResult.Rid} - err := restResult.Error() - switch restResult.StatusCode { - case 403: - err = ErrForbidden - case 409: - if s.ifNotExists { - err = nil - } else { - err = ErrConflict - } - } - return result, err -} - -// ResultCreateDatabase captures the result from CREATE DATABASE operation. -type ResultCreateDatabase struct { - // Successful flags if the operation was successful or not. - Successful bool - // InsertId holds the "_rid" if the operation was successful. - InsertId string -} - -// LastInsertId implements driver.Result.LastInsertId. -func (r *ResultCreateDatabase) LastInsertId() (int64, error) { - return 0, fmt.Errorf("this operation is not supported. {LastInsertId:%s}", r.InsertId) -} - -// RowsAffected implements driver.Result.RowsAffected. -func (r *ResultCreateDatabase) RowsAffected() (int64, error) { - if r.Successful { - return 1, nil + ignoreErrorCode := 0 + if s.ifNotExists { + ignoreErrorCode = 409 } - return 0, nil + result := buildResultNoResultSet(&restResult.RestReponse, true, restResult.Rid, ignoreErrorCode) + return result, result.err } /*----------------------------------------------------------------------*/ @@ -148,7 +120,7 @@ func (s *StmtAlterDatabase) validate() error { // Query implements driver.Stmt.Query. // This function is not implemented, use Exec instead. func (s *StmtAlterDatabase) Query(_ []driver.Value) (driver.Rows, error) { - return nil, errors.New("this operation is not supported, please use exec") + return nil, ErrQueryNotSupported } // Exec implements driver.Stmt.Exec. @@ -165,36 +137,8 @@ func (s *StmtAlterDatabase) Exec(_ []driver.Value) (driver.Result, error) { return nil, err } restResult := s.conn.restClient.ReplaceOfferForResource(getResult.Rid, s.ru, s.maxru) - result := &ResultAlterDatabase{Successful: restResult.Error() == nil} - err := restResult.Error() - switch restResult.StatusCode { - case 403: - err = ErrForbidden - case 404: - err = ErrNotFound - } - return result, err -} - -// ResultAlterDatabase captures the result from ALTER DATABASE operation. -// -// Available since v0.1.1 -type ResultAlterDatabase struct { - // Successful flags if the operation was successful or not. - Successful bool -} - -// LastInsertId implements driver.Result.LastInsertId. -func (r *ResultAlterDatabase) LastInsertId() (int64, error) { - return 0, fmt.Errorf("this operation is not supported") -} - -// RowsAffected implements driver.Result.RowsAffected. -func (r *ResultAlterDatabase) RowsAffected() (int64, error) { - if r.Successful { - return 1, nil - } - return 0, nil + result := buildResultNoResultSet(&restResult.RestReponse, true, restResult.Rid, 0) + return result, result.err } /*----------------------------------------------------------------------*/ @@ -219,25 +163,19 @@ func (s *StmtDropDatabase) validate() error { // Query implements driver.Stmt.Query. // This function is not implemented, use Exec instead. func (s *StmtDropDatabase) Query(_ []driver.Value) (driver.Rows, error) { - return nil, errors.New("this operation is not supported, please use exec") + return nil, ErrQueryNotSupported } // Exec implements driver.Stmt.Exec. // This function always return a nil driver.Result. func (s *StmtDropDatabase) Exec(_ []driver.Value) (driver.Result, error) { restResult := s.conn.restClient.DeleteDatabase(s.dbName) - err := restResult.Error() - switch restResult.StatusCode { - case 403: - err = ErrForbidden - case 404: - if s.ifExists { - err = nil - } else { - err = ErrNotFound - } + ignoreErrorCode := 0 + if s.ifExists { + ignoreErrorCode = 404 } - return nil, err + result := buildResultNoResultSet(&restResult.RestReponse, false, "", ignoreErrorCode) + return result, result.err } /*----------------------------------------------------------------------*/ @@ -258,7 +196,7 @@ func (s *StmtListDatabases) validate() error { // Exec implements driver.Stmt.Exec. // This function is not implemented, use Query instead. func (s *StmtListDatabases) Exec(_ []driver.Value) (driver.Result, error) { - return nil, errors.New("this operation is not supported, please use query") + return nil, ErrExecNotSupported } // Query implements driver.Stmt.Query. diff --git a/stmt_database_test.go b/stmt_database_test.go new file mode 100644 index 0000000..64e2321 --- /dev/null +++ b/stmt_database_test.go @@ -0,0 +1,253 @@ +package gocosmos + +import ( + "encoding/json" + "fmt" + "testing" +) + +func TestStmtCreateDatabase_Query(t *testing.T) { + testName := "TestStmtCreateDatabase_Query" + db := _openDb(t, testName) + _, err := db.Query("CREATE DATABASE dbtemp") + if err != ErrQueryNotSupported { + t.Fatalf("%s failed: expected ErrOperationNotSupported, but received %#v", testName, err) + } +} + +func TestStmtCreateDatabase_Exec(t *testing.T) { + testName := "TestStmtCreateDatabase_Exec" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + mustConflict bool + affectedRows int64 + }{ + { + name: "create_new", + initSqls: []string{fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)}, + sql: fmt.Sprintf("CREATE DATABASE %s WITH ru=400", dbname), + affectedRows: 1, + }, + { + name: "create_conflict", + sql: fmt.Sprintf("CREATE DATABASE %s WITH ru=400", dbname), + mustConflict: true, + }, + { + name: "create_new2", + initSqls: []string{fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)}, + sql: fmt.Sprintf("CREATE DATABASE %s WITH ru=400", dbname), + affectedRows: 1, + }, + { + name: "create_if_not_exists", + sql: fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s WITH ru=400", dbname), + affectedRows: 0, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err == nil { + t.Fatalf("%s failed: expected LastInsertId but received nil", testName+"/"+testCase.name) + } + lastInsertId := make(map[string]interface{}) + err = json.Unmarshal([]byte(err.Error()), &lastInsertId) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/LastInsertId", err) + } + if len(lastInsertId) != 1 { + t.Fatalf("%s failed - LastInsertId: %#v", testName+"/"+testCase.name+"/LastInsertId", lastInsertId) + } + }) + } +} + +func TestStmtAlterDatabase_Query(t *testing.T) { + testName := "TestStmtAlterDatabase_Query" + db := _openDb(t, testName) + _, err := db.Query("ALTER DATABASE dbtemp WITH ru=400") + if err != ErrQueryNotSupported { + t.Fatalf("%s failed: expected ErrOperationNotSupported, but received %#v", testName, err) + } +} + +func TestStmtAlterDatabase_Exec(t *testing.T) { + testName := "TestStmtAlterDatabase_Exec" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + mustNotFound bool + affectedRows int64 + }{ + { + name: "change_ru", + initSqls: []string{"DROP DATABASE IF EXISTS db_not_found", fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s WITH ru=400", dbname)}, + sql: fmt.Sprintf("ALTER DATABASE %s WITH ru=500", dbname), + affectedRows: 1, + }, + { + name: "change_maxru", + sql: fmt.Sprintf("ALTER DATABASE %s WITH maxru=6000", dbname), + affectedRows: 1, + }, + { + name: "db_not_found", + sql: "ALTER DATABASE db_not_found WITH maxru=6000", + mustNotFound: true, + }, + { + name: "db_no_offer", + initSqls: []string{fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbname)}, + sql: fmt.Sprintf("ALTER DATABASE %s WITH maxru=6000", dbname), + mustNotFound: true, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql) + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err == nil { + t.Fatalf("%s failed: expected LastInsertId but received nil", testName+"/"+testCase.name) + } + lastInsertId := make(map[string]interface{}) + err = json.Unmarshal([]byte(err.Error()), &lastInsertId) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/LastInsertId", err) + } + if len(lastInsertId) != 1 { + t.Fatalf("%s failed - LastInsertId: %#v", testName+"/"+testCase.name+"/LastInsertId", lastInsertId) + } + }) + } +} + +func TestStmtDropDatabase_Query(t *testing.T) { + testName := "TestStmtDropDatabase_Query" + db := _openDb(t, testName) + _, err := db.Query("DROP DATABASE dbtemp") + if err != ErrQueryNotSupported { + t.Fatalf("%s failed: expected ErrOperationNotSupported, but received %#v", testName, err) + } +} + +func TestStmtDropDatabase_Exec(t *testing.T) { + testName := "TestStmtDropDatabase_Exec" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + mustNotFound bool + affectedRows int64 + }{ + { + name: "basic", + initSqls: []string{fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbname)}, + sql: fmt.Sprintf("DROP DATABASE %s", dbname), + affectedRows: 1, + }, + { + name: "not_found", + sql: fmt.Sprintf("DROP DATABASE %s", dbname), + mustNotFound: true, + }, + { + name: "basic2", + initSqls: []string{fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbname)}, + sql: fmt.Sprintf("DROP DATABASE %s", dbname), + affectedRows: 1, + }, + { + name: "if_exists", + sql: fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql) + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err != ErrOperationNotSupported { + t.Fatalf("%s failed: expected ErrOperationNotSupported but received %#v", testName+"/"+testCase.name, err) + } + }) + } +} From c01915610574ab20ecaad61a8435e333345d3326 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Mon, 5 Jun 2023 12:55:15 +1000 Subject: [PATCH 04/13] update test --- gocosmos_test.go | 63 ------------------------------- restclient.go | 12 ++++++ stmt_database.go | 87 ++++++++++++++++++++++--------------------- stmt_database_test.go | 52 ++++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 106 deletions(-) diff --git a/gocosmos_test.go b/gocosmos_test.go index d415802..8bc1cc8 100644 --- a/gocosmos_test.go +++ b/gocosmos_test.go @@ -129,69 +129,6 @@ func TestDriver_Close(t *testing.T) { /*----------------------------------------------------------------------*/ -func Test_Exec_ListDatabases(t *testing.T) { - name := "Test_Exec_ListDatabases" - db := _openDb(t, name) - _, err := db.Exec("LIST DATABASES") - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -func Test_Query_ListDatabases(t *testing.T) { - name := "Test_Query_ListDatabases" - db := _openDb(t, name) - - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp1") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp2") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp1") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp2") - - dbRows, err := db.Query("LIST DATABASES") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - colTypes, err := dbRows.ColumnTypes() - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - numCols := len(colTypes) - result := make(map[string]map[string]interface{}) - for dbRows.Next() { - vals := make([]interface{}, numCols) - scanVals := make([]interface{}, numCols) - for i := 0; i < numCols; i++ { - scanVals[i] = &vals[i] - } - if err := dbRows.Scan(scanVals...); err == nil { - row := make(map[string]interface{}) - for i, v := range colTypes { - row[v.Name()] = vals[i] - } - id := fmt.Sprintf("%s", row["id"]) - result[id] = row - } else if err != sql.ErrNoRows { - t.Fatalf("%s failed: %s", name, err) - } - } - _, ok1 := result["dbtemp"] - _, ok2 := result["dbtemp1"] - _, ok3 := result["dbtemp2"] - if !ok1 { - t.Fatalf("%s failed: database %s not found", name, "dbtemp") - } - if !ok2 { - t.Fatalf("%s failed: database %s not found", name, "dbtemp1") - } - if !ok3 { - t.Fatalf("%s failed: database %s not found", name, "dbtemp2") - } -} - -/*----------------------------------------------------------------------*/ - func Test_Query_CreateCollection(t *testing.T) { name := "Test_Query_CreateCollection" db := _openDb(t, name) diff --git a/restclient.go b/restclient.go index efcccab..514121c 100644 --- a/restclient.go +++ b/restclient.go @@ -1158,6 +1158,18 @@ type DbInfo struct { Users string `json:"_users"` // (system-generated property) _users attribute of the database } +func (db *DbInfo) toMap() map[string]interface{} { + return map[string]interface{}{ + "id": db.Id, + "_rid": db.Rid, + "_ts": db.Ts, + "_self": db.Self, + "_etag": db.Etag, + "_colls": db.Colls, + "_users": db.Users, + } +} + // RespCreateDb captures the response from RestClient.CreateDatabase call. type RespCreateDb struct { RestReponse diff --git a/stmt_database.go b/stmt_database.go index bc92f89..8d1fd3a 100644 --- a/stmt_database.go +++ b/stmt_database.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "errors" "fmt" - "io" "strconv" ) @@ -202,52 +201,54 @@ func (s *StmtListDatabases) Exec(_ []driver.Value) (driver.Result, error) { // Query implements driver.Stmt.Query. func (s *StmtListDatabases) Query(_ []driver.Value) (driver.Rows, error) { restResult := s.conn.restClient.ListDatabases() - err := restResult.Error() - var rows driver.Rows - if err == nil { - rows = &RowsListDatabases{ - count: int(restResult.Count), - databases: restResult.Databases, - cursorCount: 0, + result := &ResultResultSet{ + err: restResult.Error(), + columnList: []string{"id", "_rid", "_ts", "_self", "_etag", "_colls", "_users"}, + } + if result.err == nil { + result.count = len(restResult.Databases) + result.rowData = make([]map[string]interface{}, result.count) + for i, db := range restResult.Databases { + result.rowData[i] = db.toMap() } } switch restResult.StatusCode { case 403: - err = ErrForbidden + result.err = ErrForbidden } - return rows, err -} - -// RowsListDatabases captures the result from LIST DATABASES operation. -type RowsListDatabases struct { - count int - databases []DbInfo - cursorCount int -} - -// Columns implements driver.Rows.Columns. -func (r *RowsListDatabases) Columns() []string { - return []string{"id", "_rid", "_ts", "_self", "_etag", "_colls", "_users"} -} - -// Close implements driver.Rows.Close. -func (r *RowsListDatabases) Close() error { - return nil + return result, result.err } -// Next implements driver.Rows.Next. -func (r *RowsListDatabases) Next(dest []driver.Value) error { - if r.cursorCount >= r.count { - return io.EOF - } - rowData := r.databases[r.cursorCount] - r.cursorCount++ - dest[0] = rowData.Id - dest[1] = rowData.Rid - dest[2] = rowData.Ts - dest[3] = rowData.Self - dest[4] = rowData.Etag - dest[5] = rowData.Colls - dest[6] = rowData.Users - return nil -} +// // RowsListDatabases captures the result from LIST DATABASES operation. +// type RowsListDatabases struct { +// count int +// databases []DbInfo +// cursorCount int +// } +// +// // Columns implements driver.Rows.Columns. +// func (r *RowsListDatabases) Columns() []string { +// return []string{"id", "_rid", "_ts", "_self", "_etag", "_colls", "_users"} +// } +// +// // Close implements driver.Rows.Close. +// func (r *RowsListDatabases) Close() error { +// return nil +// } +// +// // Next implements driver.Rows.Next. +// func (r *RowsListDatabases) Next(dest []driver.Value) error { +// if r.cursorCount >= r.count { +// return io.EOF +// } +// rowData := r.databases[r.cursorCount] +// r.cursorCount++ +// dest[0] = rowData.Id +// dest[1] = rowData.Rid +// dest[2] = rowData.Ts +// dest[3] = rowData.Self +// dest[4] = rowData.Etag +// dest[5] = rowData.Colls +// dest[6] = rowData.Users +// return nil +// } diff --git a/stmt_database_test.go b/stmt_database_test.go index 64e2321..5efd277 100644 --- a/stmt_database_test.go +++ b/stmt_database_test.go @@ -251,3 +251,55 @@ func TestStmtDropDatabase_Exec(t *testing.T) { }) } } + +func TestStmtListDatabases_Exec(t *testing.T) { + testName := "TestStmtListDatabases_Exec" + db := _openDb(t, testName) + _, err := db.Query("LIST DATABASES") + if err != ErrExecNotSupported { + t.Fatalf("%s failed: expected ErrOperationNotSupported, but received %#v", testName, err) + } +} + +func TestStmtListDatabases_Query(t *testing.T) { + testName := "TestStmtListDatabases_Query" + db := _openDb(t, testName) + dbnames := []string{"dbtemp", "dbtemp2", "dbtemp1"} + for _, dbname := range dbnames { + db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbname)) + } + defer func() { + for _, dbname := range dbnames { + db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + } + }() + dbRows, err := db.Query("LIST DATABASES") + if err != nil { + t.Fatalf("%s failed: %s", testName+"/query", err) + } + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/fetch_rows", err) + } + ok0, ok1, ok2 := false, false, false + for _, row := range rows { + if row["id"] == "dbtemp" { + ok0 = true + } + if row["id"] == "dbtemp1" { + ok1 = true + } + if row["id"] == "dbtemp2" { + ok2 = true + } + } + if !ok0 { + t.Fatalf("%s failed: database %s not found", testName, "dbtemp") + } + if !ok1 { + t.Fatalf("%s failed: database %s not found", testName, "dbtemp1") + } + if !ok2 { + t.Fatalf("%s failed: database %s not found", testName, "dbtemp2") + } +} From b3d7889cc8283bb9d37d3d10d6a1a7f3d2c5642f Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Mon, 5 Jun 2023 15:16:18 +1000 Subject: [PATCH 05/13] refactor & update tests --- .github/workflows/gocosmos.yaml | 28 +- gocosmos_test.go | 308 ----------------- restclient.go | 19 + stmt.go | 2 +- stmt_collection.go | 189 +++------- stmt_collection_test.go | 592 ++++++++++++++++++++++++++++++++ stmt_database.go | 61 +--- stmt_database_test.go | 10 +- 8 files changed, 700 insertions(+), 509 deletions(-) create mode 100644 stmt_collection_test.go diff --git a/.github/workflows/gocosmos.yaml b/.github/workflows/gocosmos.yaml index 63bf91c..18581d4 100644 --- a/.github/workflows/gocosmos.yaml +++ b/.github/workflows/gocosmos.yaml @@ -45,13 +45,39 @@ jobs: netstat -nt $env:COSMOSDB_DRIVER='gocosmos' $env:COSMOSDB_URL='AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==' - go test -cover -coverprofile="coverage_driver_database.txt" -v -count 1 -p 1 -run "TestStmt.*Database_(Exec|Query)" . + go test -cover -coverprofile="coverage_driver_database.txt" -v -count 1 -p 1 -run "TestStmt.*Databases?_(Exec|Query)" . - name: Codecov uses: codecov/codecov-action@v3 with: flags: driver_database name: driver_database + testDriverStmtCollection: + name: Test driver collection statements + runs-on: windows-latest + steps: + - name: Set up Go env + uses: actions/setup-go@v2 + with: + go-version: ^1.13 + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + - name: Test + run: | + choco install azure-cosmosdb-emulator + & "C:\Program Files\Azure Cosmos DB Emulator\Microsoft.Azure.Cosmos.Emulator.exe" /DisableRateLimiting /NoUI /NoExplorer + Start-Sleep -s 60 + try { Invoke-RestMethod -Method GET https://127.0.0.1:8081/ } catch {} + netstat -nt + $env:COSMOSDB_DRIVER='gocosmos' + $env:COSMOSDB_URL='AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==' + go test -cover -coverprofile="coverage_driver_collection.txt" -v -count 1 -p 1 -run "TestStmt.*Collections?_(Exec|Query)" . + - name: Codecov + uses: codecov/codecov-action@v3 + with: + flags: driver_collection + name: driver_collection + testDriverSelect: name: Test driver SELECT query runs-on: windows-latest diff --git a/gocosmos_test.go b/gocosmos_test.go index 8bc1cc8..ad222b4 100644 --- a/gocosmos_test.go +++ b/gocosmos_test.go @@ -3,7 +3,6 @@ package gocosmos import ( "context" "database/sql" - "fmt" "os" "regexp" "strings" @@ -129,313 +128,6 @@ func TestDriver_Close(t *testing.T) { /*----------------------------------------------------------------------*/ -func Test_Query_CreateCollection(t *testing.T) { - name := "Test_Query_CreateCollection" - db := _openDb(t, name) - _, err := db.Query("CREATE COLLECTION dbtemp.tbltemp WITH pk=/id") - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -func Test_Exec_CreateCollection(t *testing.T) { - name := "Test_Exec_CreateCollection" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_exists") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - - // first creation should be successful - result, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/id WITH ru=400") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if regexp.MustCompile(`(?i){\s*LastInsertId\s*:\s*[^}]+?\s*}`).FindString(err.Error()) == "" { - t.Fatalf("%s failed: can not catch LastInsertId / %s", name, err) - } - if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - // second creation should return ErrConflict - _, err = db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/id WITH ru=400") - if err != ErrConflict { - t.Fatalf("%s failed: expected ErrConflict but received %#v", name, err) - } - - // third creation should be successful with "IF NOT EXISTS" - result, err = db.Exec("CREATE TABLE IF NOT EXISTS dbtemp.tbltemp WITH largepk=/a/b/c WITH maxru=4000 WITH uk=/a;/b,/c/d") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 && err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } - if numRows, err := result.RowsAffected(); numRows != 0 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=0/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - result, err = db.Exec("CREATE TABLE IF NOT EXISTS dbtemp.tbltemp1 WITH largepk=/a/b/c WITH maxru=4000 WITH uk=/a;/b,/c/d") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 && err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } - if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - _, err = db.Exec(`CREATE COLLECTION db_not_exists.table WITH pk=/a`) - if err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } -} - -func Test_Exec_CreateCollectionDefaultDb(t *testing.T) { - name := "Test_Exec_CreateCollectionDefaultDb" - dbName := "mydefaultdb" - db := _openDefaultDb(t, name, dbName) - - db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)) - db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbName)) - defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)) - - // first creation should be successful - result, err := db.Exec("CREATE COLLECTION tbltemp WITH pk=/id WITH ru=400") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if regexp.MustCompile(`(?i){\s*LastInsertId\s*:\s*[^}]+?\s*}`).FindString(err.Error()) == "" { - t.Fatalf("%s failed: can not catch LastInsertId / %s", name, err) - } - if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - // second creation should return ErrConflict - _, err = db.Exec("CREATE COLLECTION tbltemp WITH pk=/id WITH ru=400") - if err != ErrConflict { - t.Fatalf("%s failed: expected ErrConflict but received %#v", name, err) - } - - // third creation should be successful with "IF NOT EXISTS" - result, err = db.Exec("CREATE TABLE IF NOT EXISTS tbltemp WITH largepk=/a/b/c WITH maxru=4000 WITH uk=/a;/b,/c/d") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 && err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } - if numRows, err := result.RowsAffected(); numRows != 0 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=0/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } -} - -func Test_Query_AlterCollection(t *testing.T) { - name := "Test_Query_AlterCollection" - db := _openDb(t, name) - _, err := db.Query("ALTER COLLECTION dbtemp.tbltemp WITH ru=400") - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -func Test_Exec_AlterCollection(t *testing.T) { - name := "Test_Exec_AlterCollection" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_exists") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/id") - - result, err := db.Exec("ALTER COLLECTION dbtemp.tbltemp WITH ru=500") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } - if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - result, err = db.Exec("ALTER TABLE dbtemp.tbltemp WITH maxru=6000") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } - if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - _, err = db.Exec(`ALTER COLLECTION dbtemp.tbl_not_found WITH ru=400`) - if err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } - - _, err = db.Exec(`ALTER COLLECTION db_not_exists.table WITH ru=400`) - if err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } -} - -func Test_Exec_AlterCollectionDefaultDb(t *testing.T) { - name := "Test_Exec_AlterCollectionDefaultDb" - dbName := "mydefaultdb" - db := _openDefaultDb(t, name, dbName) - - db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)) - db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbName)) - defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)) - db.Exec("CREATE COLLECTION tbltemp WITH pk=/id") - - result, err := db.Exec("ALTER COLLECTION tbltemp WITH ru=500") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } - if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - result, err = db.Exec("ALTER TABLE tbltemp WITH maxru=6000") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } - if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - _, err = db.Exec(`ALTER COLLECTION tbl_not_found WITH ru=400`) - if err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } - - _, err = db.Exec(`ALTER COLLECTION db_not_exists.table WITH ru=400`) - if err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } -} - -func Test_Query_DropCollection(t *testing.T) { - name := "Test_Query_DropCollection" - db := _openDb(t, name) - _, err := db.Query("DROP COLLECTION dbtemp.tbltemp") - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -func Test_Exec_DropCollection(t *testing.T) { - name := "Test_Exec_DropCollection" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE COLLECTION IF NOT EXISTS dbtemp.tbltemp WITH pk=/id") - - // first drop should be successful - _, err := db.Exec("DROP COLLECTION dbtemp.tbltemp") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - // second drop should return ErrNotFound - _, err = db.Exec("DROP COLLECTION dbtemp.tbltemp") - if err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } - - // third drop should be successful with "IF EXISTS" - _, err = db.Exec("DROP TABLE IF EXISTS dbtemp.tbltemp") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } -} - -func Test_Exec_ListCollections(t *testing.T) { - name := "Test_Exec_ListCollections" - db := _openDb(t, name) - _, err := db.Exec("LIST COLLECTIONS FROM dbtemp") - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -func Test_Query_ListCollections(t *testing.T) { - name := "Test_Query_ListCollections" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/a") - db.Exec("CREATE TABLE dbtemp.tbltemp2 WITH pk=/b") - db.Exec("CREATE COLLECTION dbtemp.tbltemp1 WITH pk=/c") - - dbRows, err := db.Query("LIST COLLECTIONS FROM dbtemp") - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - colTypes, err := dbRows.ColumnTypes() - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - numCols := len(colTypes) - result := make(map[string]map[string]interface{}) - for dbRows.Next() { - vals := make([]interface{}, numCols) - scanVals := make([]interface{}, numCols) - for i := 0; i < numCols; i++ { - scanVals[i] = &vals[i] - } - if err := dbRows.Scan(scanVals...); err == nil { - row := make(map[string]interface{}) - for i, v := range colTypes { - row[v.Name()] = vals[i] - } - id := fmt.Sprintf("%s", row["id"]) - result[id] = row - } else if err != sql.ErrNoRows { - t.Fatalf("%s failed: %s", name, err) - } - } - _, ok1 := result["tbltemp"] - _, ok2 := result["tbltemp1"] - _, ok3 := result["tbltemp2"] - if !ok1 { - t.Fatalf("%s failed: collection %s not found", name, "dbtemp.tbltemp") - } - if !ok2 { - t.Fatalf("%s failed: collection %s not found", name, "dbtemp.tbltemp1") - } - if !ok3 { - t.Fatalf("%s failed: collection %s not found", name, "dbtemp.tbltemp2") - } - - _, err = db.Query("LIST COLLECTIONS FROM db_not_found") - if err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } -} - func Test_Query_Insert(t *testing.T) { name := "Test_Query_Insert" db := _openDb(t, name) diff --git a/restclient.go b/restclient.go index 514121c..109fdc3 100644 --- a/restclient.go +++ b/restclient.go @@ -1212,6 +1212,25 @@ type CollInfo struct { GeospatialConfig map[string]interface{} `json:"geospatialConfig"` // Geo-spatial configuration settings for collection } +func (c *CollInfo) toMap() map[string]interface{} { + return map[string]interface{}{ + "id": c.Id, + "_rid": c.Rid, + "_ts": c.Ts, + "_self": c.Self, + "_etag": c.Etag, + "_docs": c.Docs, + "_sprocs": c.Sprocs, + "_triggers": c.Triggers, + "_udfs": c.Udfs, + "_conflicts": c.Conflicts, + "indexingPolicy": c.IndexingPolicy, + "partitionKey": c.PartitionKey, + "conflictResolutionPolicy": c.ConflictResolutionPolicy, + "geospatialConfig": c.GeospatialConfig, + } +} + // RespCreateColl captures the response from RestClient.CreateCollection call. type RespCreateColl struct { RestReponse diff --git a/stmt.go b/stmt.go index b807b12..cf49ff3 100644 --- a/stmt.go +++ b/stmt.go @@ -209,7 +209,7 @@ func parseQueryWithDefaultDb(c *Conn, defaultDb, query string) (driver.Stmt, err return nil, fmt.Errorf("invalid query: %s", query) } -// Stmt is Azure CosmosDB prepared statement handle. +// Stmt is Azure Cosmos DB abstract implementation of driver.Stmt. type Stmt struct { query string // the SQL query conn *Conn // the connection that this prepared statement is bound to diff --git a/stmt_collection.go b/stmt_collection.go index 5e732be..f56f287 100644 --- a/stmt_collection.go +++ b/stmt_collection.go @@ -4,16 +4,18 @@ import ( "database/sql/driver" "errors" "fmt" - "io" "regexp" "strconv" ) -// StmtCreateCollection implements "CREATE COLLECTION" operation. +// StmtCreateCollection implements "CREATE COLLECTION" statement. // // Syntax: // -// CREATE COLLECTION|TABLE [IF NOT EXISTS] [.] [WITH RU|MAXRU=ru] [WITH UK=/path1:/path2,/path3;/path4] +// CREATE COLLECTION|TABLE [IF NOT EXISTS] [.] +// +// [[,] WITH RU|MAXRU=ru] +// [[,] WITH UK=/path1:/path2,/path3;/path4] // // - ru: an integer specifying CosmosDB's collection throughput expressed in RU/s. Supply either RU or MAXRU, not both! // @@ -94,14 +96,13 @@ func (s *StmtCreateCollection) validate() error { return nil } -// Query implements driver.Stmt.Query. +// Query implements driver.Stmt/Query. // This function is not implemented, use Exec instead. func (s *StmtCreateCollection) Query(_ []driver.Value) (driver.Rows, error) { - return nil, errors.New("this operation is not supported, please use exec") + return nil, ErrQueryNotSupported } -// Exec implements driver.Stmt.Exec. -// Upon successful call, this function returns (*ResultCreateCollection, nil). +// Exec implements driver.Stmt/Exec. func (s *StmtCreateCollection) Exec(_ []driver.Value) (driver.Result, error) { spec := CollectionSpec{DbName: s.dbName, CollName: s.collName, Ru: s.ru, MaxRu: s.maxru, PartitionKeyInfo: map[string]interface{}{ @@ -120,47 +121,17 @@ func (s *StmtCreateCollection) Exec(_ []driver.Value) (driver.Result, error) { } restResult := s.conn.restClient.CreateCollection(spec) - result := &ResultCreateCollection{Successful: restResult.Error() == nil, InsertId: restResult.Rid} - err := restResult.Error() - switch restResult.StatusCode { - case 403: - err = ErrForbidden - case 404: - err = ErrNotFound - case 409: - if s.ifNotExists { - err = nil - } else { - err = ErrConflict - } - } - return result, err -} - -// ResultCreateCollection captures the result from CREATE COLLECTION operation. -type ResultCreateCollection struct { - // Successful flags if the operation was successful or not. - Successful bool - // InsertId holds the "_rid" if the operation was successful. - InsertId string -} - -// LastInsertId implements driver.Result.LastInsertId. -func (r *ResultCreateCollection) LastInsertId() (int64, error) { - return 0, fmt.Errorf("this operation is not supported. {LastInsertId:%s}", r.InsertId) -} - -// RowsAffected implements driver.Result.RowsAffected. -func (r *ResultCreateCollection) RowsAffected() (int64, error) { - if r.Successful { - return 1, nil + ignoreErrorCode := 0 + if s.ifNotExists { + ignoreErrorCode = 409 } - return 0, nil + result := buildResultNoResultSet(&restResult.RestReponse, true, restResult.Rid, ignoreErrorCode) + return result, result.err } /*----------------------------------------------------------------------*/ -// StmtAlterCollection implements "ALTER COLLECTION" operation. +// StmtAlterCollection implements "ALTER COLLECTION" statement. // // Syntax: // @@ -210,14 +181,13 @@ func (s *StmtAlterCollection) validate() error { return nil } -// Query implements driver.Stmt.Query. +// Query implements driver.Stmt/Query. // This function is not implemented, use Exec instead. func (s *StmtAlterCollection) Query(_ []driver.Value) (driver.Rows, error) { - return nil, errors.New("this operation is not supported, please use exec") + return nil, ErrQueryNotSupported } -// Exec implements driver.Stmt.Exec. -// Upon successful call, this function returns (*ResultAlterCollection, nil). +// Exec implements driver.Stmt/Exec. func (s *StmtAlterCollection) Exec(_ []driver.Value) (driver.Result, error) { getResult := s.conn.restClient.GetCollection(s.dbName, s.collName) if err := getResult.Error(); err != nil { @@ -230,41 +200,13 @@ func (s *StmtAlterCollection) Exec(_ []driver.Value) (driver.Result, error) { return nil, err } restResult := s.conn.restClient.ReplaceOfferForResource(getResult.Rid, s.ru, s.maxru) - result := &ResultAlterCollection{Successful: restResult.Error() == nil} - err := restResult.Error() - switch restResult.StatusCode { - case 403: - err = ErrForbidden - case 404: - err = ErrNotFound - } - return result, err -} - -// ResultAlterCollection captures the result from ALTER COLLECTION operation. -// -// Available since v0.1.1 -type ResultAlterCollection struct { - // Successful flags if the operation was successful or not. - Successful bool -} - -// LastInsertId implements driver.Result.LastInsertId. -func (r *ResultAlterCollection) LastInsertId() (int64, error) { - return 0, fmt.Errorf("this operation is not supported") -} - -// RowsAffected implements driver.Result.RowsAffected. -func (r *ResultAlterCollection) RowsAffected() (int64, error) { - if r.Successful { - return 1, nil - } - return 0, nil + result := buildResultNoResultSet(&restResult.RestReponse, true, restResult.Rid, 0) + return result, result.err } /*----------------------------------------------------------------------*/ -// StmtDropCollection implements "DROP COLLECTION" operation. +// StmtDropCollection implements "DROP COLLECTION" statement. // // Syntax: // @@ -285,33 +227,26 @@ func (s *StmtDropCollection) validate() error { return nil } -// Query implements driver.Stmt.Query. +// Query implements driver.Stmt/Query. // This function is not implemented, use Exec instead. func (s *StmtDropCollection) Query(_ []driver.Value) (driver.Rows, error) { - return nil, errors.New("this operation is not supported, please use exec") + return nil, ErrQueryNotSupported } -// Exec implements driver.Stmt.Exec. -// This function always return a nil driver.Result. +// Exec implements driver.Stmt/Exec. func (s *StmtDropCollection) Exec(_ []driver.Value) (driver.Result, error) { restResult := s.conn.restClient.DeleteCollection(s.dbName, s.collName) - err := restResult.Error() - switch restResult.StatusCode { - case 403: - err = ErrForbidden - case 404: - if s.ifExists { - err = nil - } else { - err = ErrNotFound - } + ignoreErrorCode := 0 + if s.ifExists { + ignoreErrorCode = 404 } - return nil, err + result := buildResultNoResultSet(&restResult.RestReponse, false, "", ignoreErrorCode) + return result, result.err } /*----------------------------------------------------------------------*/ -// StmtListCollections implements "LIST DATABASES" operation. +// StmtListCollections implements "LIST DATABASES" statement. // // Syntax: // @@ -328,67 +263,31 @@ func (s *StmtListCollections) validate() error { return nil } -// Exec implements driver.Stmt.Exec. +// Exec implements driver.Stmt/Exec. // This function is not implemented, use Query instead. func (s *StmtListCollections) Exec(_ []driver.Value) (driver.Result, error) { - return nil, errors.New("this operation is not supported, please use query") + return nil, ErrExecNotSupported } -// Query implements driver.Stmt.Query. +// Query implements driver.Stmt/Query. func (s *StmtListCollections) Query(_ []driver.Value) (driver.Rows, error) { restResult := s.conn.restClient.ListCollections(s.dbName) - err := restResult.Error() - var rows driver.Rows - if err == nil { - rows = &RowsListCollections{ - count: int(restResult.Count), - collections: restResult.Collections, - cursorCount: 0, + result := &ResultResultSet{ + err: restResult.Error(), + columnList: []string{"id", "indexingPolicy", "_rid", "_ts", "_self", "_etag", "_docs", "_sprocs", "_triggers", "_udfs", "_conflicts"}, + } + if result.err == nil { + result.count = len(restResult.Collections) + result.rowData = make([]map[string]interface{}, result.count) + for i, coll := range restResult.Collections { + result.rowData[i] = coll.toMap() } } switch restResult.StatusCode { case 403: - err = ErrForbidden + result.err = ErrForbidden case 404: - err = ErrNotFound - } - return rows, err -} - -// RowsListCollections captures the result from LIST COLLECTIONS operation. -type RowsListCollections struct { - count int - collections []CollInfo - cursorCount int -} - -// Columns implements driver.Rows.Columns. -func (r *RowsListCollections) Columns() []string { - return []string{"id", "indexingPolicy", "_rid", "_ts", "_self", "_etag", "_docs", "_sprocs", "_triggers", "_udfs", "_conflicts"} -} - -// Close implements driver.Rows.Close. -func (r *RowsListCollections) Close() error { - return nil -} - -// Next implements driver.Rows.Next. -func (r *RowsListCollections) Next(dest []driver.Value) error { - if r.cursorCount >= r.count { - return io.EOF + result.err = ErrNotFound } - rowData := r.collections[r.cursorCount] - r.cursorCount++ - dest[0] = rowData.Id - dest[1] = rowData.IndexingPolicy - dest[2] = rowData.Rid - dest[3] = rowData.Ts - dest[4] = rowData.Self - dest[5] = rowData.Etag - dest[6] = rowData.Docs - dest[7] = rowData.Sprocs - dest[8] = rowData.Triggers - dest[9] = rowData.Udfs - dest[10] = rowData.Conflicts - return nil + return result, result.err } diff --git a/stmt_collection_test.go b/stmt_collection_test.go new file mode 100644 index 0000000..035b376 --- /dev/null +++ b/stmt_collection_test.go @@ -0,0 +1,592 @@ +package gocosmos + +import ( + "encoding/json" + "fmt" + "testing" +) + +func TestStmtCreateCollection_Query(t *testing.T) { + testName := "TestStmtCreateCollection_Query" + db := _openDb(t, testName) + _, err := db.Query("CREATE COLLECTION dbtemp.tbltemp WITH pk=/id") + if err != ErrQueryNotSupported { + t.Fatalf("%s failed: expected ErrQueryNotSupported, but received %#v", testName, err) + } +} + +func TestStmtCreateCollection_Exec(t *testing.T) { + testName := "TestStmtCreateCollection_Exec" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + mustConflict bool + mustNotFound bool + affectedRows int64 + }{ + { + name: "create_new", + initSqls: []string{"DROP DATABASE IF EXISTS db_not_exists", fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbname)}, + sql: fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/id WITH ru=400", dbname), + affectedRows: 1, + }, + { + name: "create_conflict", + sql: fmt.Sprintf("CREATE TABLE %s.tbltemp WITH pk=/id WITH ru=400", dbname), + mustConflict: true, + }, + { + name: "create_if_not_exists", + sql: fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.tbltemp WITH largepk=/a/b/c WITH maxru=4000 WITH uk=/a;/b,/c/d", dbname), + affectedRows: 0, + }, + { + name: "create_if_not_exists2", + sql: fmt.Sprintf("CREATE COLLECTION IF NOT EXISTS %s.tbltemp1 WITH largepk=/a/b/c WITH maxru=4000 WITH uk=/a;/b,/c/d", dbname), + affectedRows: 1, + }, + { + name: "create_not_found", + sql: "CREATE COLLECTION db_not_exists.table WITH pk=/a", + mustNotFound: true, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err == nil { + t.Fatalf("%s failed: expected LastInsertId but received nil", testName+"/"+testCase.name) + } + lastInsertId := make(map[string]interface{}) + err = json.Unmarshal([]byte(err.Error()), &lastInsertId) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/LastInsertId", err) + } + if len(lastInsertId) != 1 { + t.Fatalf("%s failed - LastInsertId: %#v", testName+"/"+testCase.name+"/LastInsertId", lastInsertId) + } + }) + } +} + +func TestStmtCreateCollection_Exec_DefaultDb(t *testing.T) { + testName := "TestStmtCreateCollection_Exec_DefaultDb" + dbname := "dbdefault" + db := _openDefaultDb(t, testName, dbname) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + mustConflict bool + mustNotFound bool + affectedRows int64 + }{ + { + name: "create_new", + initSqls: []string{fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbname)}, + sql: "CREATE COLLECTION tbltemp WITH pk=/id WITH ru=400", + affectedRows: 1, + }, + { + name: "create_conflict", + sql: "CREATE TABLE tbltemp WITH pk=/id WITH ru=400", + mustConflict: true, + }, + { + name: "create_if_not_exists", + sql: "CREATE TABLE IF NOT EXISTS tbltemp WITH largepk=/a/b/c WITH maxru=4000 WITH uk=/a;/b,/c/d", + affectedRows: 0, + }, + { + name: "create_if_not_exists2", + sql: "CREATE COLLECTION IF NOT EXISTS tbltemp1 WITH largepk=/a/b/c WITH maxru=4000 WITH uk=/a;/b,/c/d", + affectedRows: 1, + }, + { + name: "create_not_found", + initSqls: []string{fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)}, + sql: "CREATE COLLECTION table WITH pk=/a", + mustNotFound: true, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err == nil { + t.Fatalf("%s failed: expected LastInsertId but received nil", testName+"/"+testCase.name) + } + lastInsertId := make(map[string]interface{}) + err = json.Unmarshal([]byte(err.Error()), &lastInsertId) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/LastInsertId", err) + } + if len(lastInsertId) != 1 { + t.Fatalf("%s failed - LastInsertId: %#v", testName+"/"+testCase.name+"/LastInsertId", lastInsertId) + } + }) + } +} + +func TestStmtAlterCollection_Query(t *testing.T) { + testName := "TestStmtAlterCollection_Query" + db := _openDb(t, testName) + _, err := db.Query("ALTER COLLECTION dbtemp.tbltemp WITH ru=400") + if err != ErrQueryNotSupported { + t.Fatalf("%s failed: expected ErrQueryNotSupported, but received %#v", testName, err) + } +} + +func TestStmtAlterCollection_Exec(t *testing.T) { + testName := "TestStmtAlterCollection_Exec" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + mustConflict bool + mustNotFound bool + affectedRows int64 + }{ + { + name: "change_ru", + initSqls: []string{"DROP DATABASE IF EXISTS db_not_exists", fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), fmt.Sprintf("CREATE DATABASE %s", dbname), fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/id", dbname)}, + sql: fmt.Sprintf("ALTER COLLECTION %s.tbltemp WITH ru=500", dbname), + affectedRows: 1, + }, + { + name: "change_maxru", + sql: fmt.Sprintf("ALTER TABLE %s.tbltemp WITH maxru=6000", dbname), + affectedRows: 1, + }, + { + name: "collection_not_found", + sql: fmt.Sprintf("ALTER COLLECTION %s.tbl_not_found WITH ru=400", dbname), + mustNotFound: true, + }, + { + name: "db_not_found", + sql: "ALTER COLLECTION db_not_exists.table WITH ru=400", + mustNotFound: true, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err == nil { + t.Fatalf("%s failed: expected LastInsertId but received nil", testName+"/"+testCase.name) + } + lastInsertId := make(map[string]interface{}) + err = json.Unmarshal([]byte(err.Error()), &lastInsertId) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/LastInsertId", err) + } + if len(lastInsertId) != 1 { + t.Fatalf("%s failed - LastInsertId: %#v", testName+"/"+testCase.name+"/LastInsertId", lastInsertId) + } + }) + } +} + +func TestStmtAlterCollection_Exec_DefaultDb(t *testing.T) { + testName := "TestStmtAlterCollection_Exec_DefaultDb" + dbname := "dbdefault" + db := _openDefaultDb(t, testName, dbname) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + mustConflict bool + mustNotFound bool + affectedRows int64 + }{ + { + name: "change_ru", + initSqls: []string{"DROP DATABASE IF EXISTS db_not_exists", fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), fmt.Sprintf("CREATE DATABASE %s", dbname), fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/id", dbname)}, + sql: "ALTER COLLECTION tbltemp WITH ru=500", + affectedRows: 1, + }, + { + name: "change_maxru", + sql: "ALTER TABLE tbltemp WITH maxru=6000", + affectedRows: 1, + }, + { + name: "collection_not_found", + sql: "ALTER COLLECTION tbl_not_found WITH ru=400", + mustNotFound: true, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err == nil { + t.Fatalf("%s failed: expected LastInsertId but received nil", testName+"/"+testCase.name) + } + lastInsertId := make(map[string]interface{}) + err = json.Unmarshal([]byte(err.Error()), &lastInsertId) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/LastInsertId", err) + } + if len(lastInsertId) != 1 { + t.Fatalf("%s failed - LastInsertId: %#v", testName+"/"+testCase.name+"/LastInsertId", lastInsertId) + } + }) + } +} + +func TestStmtDropCollection_Query(t *testing.T) { + testName := "TestStmtDropCollection_Query" + db := _openDb(t, testName) + _, err := db.Query("DROP COLLECTION dbtemp.tbltemp") + if err != ErrQueryNotSupported { + t.Fatalf("%s failed: expected ErrQueryNotSupported, but received %#v", testName, err) + } +} + +func TestStmtDropCollection_Exec(t *testing.T) { + testName := "TestStmtDropCollection_Exec" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + mustConflict bool + mustNotFound bool + affectedRows int64 + }{ + { + name: "basic", + initSqls: []string{"DROP DATABASE IF EXISTS db_not_exists", fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), fmt.Sprintf("CREATE DATABASE %s", dbname), fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/id", dbname)}, + sql: fmt.Sprintf("DROP COLLECTION %s.tbltemp", dbname), + affectedRows: 1, + }, + { + name: "not_found", + sql: fmt.Sprintf("DROP TABLE %s.tbltemp", dbname), + mustNotFound: true, + }, + { + name: "if_exists", + sql: fmt.Sprintf("DROP COLLECTION IF EXISTS %s.tbltemp", dbname), + affectedRows: 0, + }, + { + name: "db_not_found", + sql: "DROP TABLE db_not_exists.table", + mustNotFound: true, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + }) + } +} + +func TestStmtDropCollection_Exec_DefaultDb(t *testing.T) { + testName := "TestStmtDropCollection_Exec_DefaultDb" + dbname := "dbdefault" + db := _openDefaultDb(t, testName, dbname) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + mustConflict bool + mustNotFound bool + affectedRows int64 + }{ + { + name: "basic", + initSqls: []string{"DROP DATABASE IF EXISTS db_not_exists", fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), fmt.Sprintf("CREATE DATABASE %s", dbname), fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/id", dbname)}, + sql: "DROP COLLECTION tbltemp", + affectedRows: 1, + }, + { + name: "not_found", + sql: "DROP TABLE tbltemp", + mustNotFound: true, + }, + { + name: "if_exists", + sql: "DROP COLLECTION IF EXISTS tbltemp", + affectedRows: 0, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + }) + } +} + +func TestStmtListCollections_Exec(t *testing.T) { + testName := "TestStmtListCollections_Exec" + db := _openDb(t, testName) + _, err := db.Exec("LIST COLLECTIONS FROM dbtemp") + if err != ErrExecNotSupported { + t.Fatalf("%s failed: expected ErrExecNotSupported, but received %#v", testName, err) + } +} + +func TestStmtListCollections_Query(t *testing.T) { + testName := "TestStmtListCollections_Query" + db := _openDb(t, testName) + dbname := "dbtemp" + db.Exec("DROP DATABASE IF EXISTS db_not_found") + collNames := []string{"tbltemp", "tbltemp2", "tbltemp1"} + db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname)) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + for _, collName := range collNames { + db.Exec(fmt.Sprintf("CREATE COLLECTION %s.%s WITH pk=/id", dbname, collName)) + } + + dbRows, err := db.Query(fmt.Sprintf("LIST COLLECTIONS FROM %s", dbname)) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/query", err) + } + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/fetch_rows", err) + } + ok0, ok1, ok2 := false, false, false + for _, row := range rows { + if row["id"] == "tbltemp" { + ok0 = true + } + if row["id"] == "tbltemp1" { + ok1 = true + } + if row["id"] == "tbltemp2" { + ok2 = true + } + } + if !ok0 { + t.Fatalf("%s failed: collection %s not found", testName, "tbltemp") + } + if !ok1 { + t.Fatalf("%s failed: collection %s not found", testName, "tbltemp1") + } + if !ok2 { + t.Fatalf("%s failed: collection %s not found", testName, "tbltemp2") + } + + _, err = db.Query("LIST COLLECTIONS FROM db_not_found") + if err != ErrNotFound { + t.Fatalf("%s failed: expected ErrNotFound but received %#v", testName, err) + } +} + +func TestStmtListCollections_Query_DefaultDb(t *testing.T) { + testName := "TestStmtListCollections_Query_DefaultDb" + dbname := "dbdefault" + db := _openDefaultDb(t, testName, dbname) + db.Exec("DROP DATABASE IF EXISTS db_not_found") + collNames := []string{"tbltemp", "tbltemp2", "tbltemp1"} + db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname)) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + for _, collName := range collNames { + db.Exec(fmt.Sprintf("CREATE COLLECTION %s.%s WITH pk=/id", dbname, collName)) + } + + dbRows, err := db.Query("LIST TABLES") + if err != nil { + t.Fatalf("%s failed: %s", testName+"/query", err) + } + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/fetch_rows", err) + } + ok0, ok1, ok2 := false, false, false + for _, row := range rows { + if row["id"] == "tbltemp" { + ok0 = true + } + if row["id"] == "tbltemp1" { + ok1 = true + } + if row["id"] == "tbltemp2" { + ok2 = true + } + } + if !ok0 { + t.Fatalf("%s failed: collection %s not found", testName, "tbltemp") + } + if !ok1 { + t.Fatalf("%s failed: collection %s not found", testName, "tbltemp1") + } + if !ok2 { + t.Fatalf("%s failed: collection %s not found", testName, "tbltemp2") + } + + _, err = db.Query("LIST COLLECTIONS FROM db_not_found") + if err != ErrNotFound { + t.Fatalf("%s failed: expected ErrNotFound but received %#v", testName, err) + } +} diff --git a/stmt_database.go b/stmt_database.go index 8d1fd3a..0fb4d81 100644 --- a/stmt_database.go +++ b/stmt_database.go @@ -7,7 +7,7 @@ import ( "strconv" ) -// StmtCreateDatabase implements "CREATE DATABASE" operation. +// StmtCreateDatabase implements "CREATE DATABASE" statement. // // Syntax: // @@ -52,14 +52,13 @@ func (s *StmtCreateDatabase) validate() error { return nil } -// Query implements driver.Stmt.Query. +// Query implements driver.Stmt/Query. // This function is not implemented, use Exec instead. func (s *StmtCreateDatabase) Query(_ []driver.Value) (driver.Rows, error) { return nil, ErrQueryNotSupported } -// Exec implements driver.Stmt.Exec. -// Upon successful call, this function return (*ResultCreateDatabase, nil). +// Exec implements driver.Stmt/Exec. func (s *StmtCreateDatabase) Exec(_ []driver.Value) (driver.Result, error) { restResult := s.conn.restClient.CreateDatabase(DatabaseSpec{Id: s.dbName, Ru: s.ru, MaxRu: s.maxru}) ignoreErrorCode := 0 @@ -72,7 +71,7 @@ func (s *StmtCreateDatabase) Exec(_ []driver.Value) (driver.Result, error) { /*----------------------------------------------------------------------*/ -// StmtAlterDatabase implements "ALTER DATABASE" operation. +// StmtAlterDatabase implements "ALTER DATABASE" statement. // // Syntax: // @@ -116,14 +115,13 @@ func (s *StmtAlterDatabase) validate() error { return nil } -// Query implements driver.Stmt.Query. +// Query implements driver.Stmt/Query. // This function is not implemented, use Exec instead. func (s *StmtAlterDatabase) Query(_ []driver.Value) (driver.Rows, error) { return nil, ErrQueryNotSupported } -// Exec implements driver.Stmt.Exec. -// Upon successful call, this function return (*ResultAlterDatabase, nil). +// Exec implements driver.Stmt/Exec. func (s *StmtAlterDatabase) Exec(_ []driver.Value) (driver.Result, error) { getResult := s.conn.restClient.GetDatabase(s.dbName) if err := getResult.Error(); err != nil { @@ -142,7 +140,7 @@ func (s *StmtAlterDatabase) Exec(_ []driver.Value) (driver.Result, error) { /*----------------------------------------------------------------------*/ -// StmtDropDatabase implements "DROP DATABASE" operation. +// StmtDropDatabase implements "DROP DATABASE" statement. // // Syntax: // @@ -159,14 +157,13 @@ func (s *StmtDropDatabase) validate() error { return nil } -// Query implements driver.Stmt.Query. +// Query implements driver.Stmt/Query. // This function is not implemented, use Exec instead. func (s *StmtDropDatabase) Query(_ []driver.Value) (driver.Rows, error) { return nil, ErrQueryNotSupported } -// Exec implements driver.Stmt.Exec. -// This function always return a nil driver.Result. +// Exec implements driver.Stmt/Exec. func (s *StmtDropDatabase) Exec(_ []driver.Value) (driver.Result, error) { restResult := s.conn.restClient.DeleteDatabase(s.dbName) ignoreErrorCode := 0 @@ -179,7 +176,7 @@ func (s *StmtDropDatabase) Exec(_ []driver.Value) (driver.Result, error) { /*----------------------------------------------------------------------*/ -// StmtListDatabases implements "LIST DATABASES" operation. +// StmtListDatabases implements "LIST DATABASES" statement. // // Syntax: // @@ -192,13 +189,13 @@ func (s *StmtListDatabases) validate() error { return nil } -// Exec implements driver.Stmt.Exec. +// Exec implements driver.Stmt/Exec. // This function is not implemented, use Query instead. func (s *StmtListDatabases) Exec(_ []driver.Value) (driver.Result, error) { return nil, ErrExecNotSupported } -// Query implements driver.Stmt.Query. +// Query implements driver.Stmt/Query. func (s *StmtListDatabases) Query(_ []driver.Value) (driver.Rows, error) { restResult := s.conn.restClient.ListDatabases() result := &ResultResultSet{ @@ -218,37 +215,3 @@ func (s *StmtListDatabases) Query(_ []driver.Value) (driver.Rows, error) { } return result, result.err } - -// // RowsListDatabases captures the result from LIST DATABASES operation. -// type RowsListDatabases struct { -// count int -// databases []DbInfo -// cursorCount int -// } -// -// // Columns implements driver.Rows.Columns. -// func (r *RowsListDatabases) Columns() []string { -// return []string{"id", "_rid", "_ts", "_self", "_etag", "_colls", "_users"} -// } -// -// // Close implements driver.Rows.Close. -// func (r *RowsListDatabases) Close() error { -// return nil -// } -// -// // Next implements driver.Rows.Next. -// func (r *RowsListDatabases) Next(dest []driver.Value) error { -// if r.cursorCount >= r.count { -// return io.EOF -// } -// rowData := r.databases[r.cursorCount] -// r.cursorCount++ -// dest[0] = rowData.Id -// dest[1] = rowData.Rid -// dest[2] = rowData.Ts -// dest[3] = rowData.Self -// dest[4] = rowData.Etag -// dest[5] = rowData.Colls -// dest[6] = rowData.Users -// return nil -// } diff --git a/stmt_database_test.go b/stmt_database_test.go index 5efd277..b7c2d7f 100644 --- a/stmt_database_test.go +++ b/stmt_database_test.go @@ -11,7 +11,7 @@ func TestStmtCreateDatabase_Query(t *testing.T) { db := _openDb(t, testName) _, err := db.Query("CREATE DATABASE dbtemp") if err != ErrQueryNotSupported { - t.Fatalf("%s failed: expected ErrOperationNotSupported, but received %#v", testName, err) + t.Fatalf("%s failed: expected ErrQueryNotSupported, but received %#v", testName, err) } } @@ -96,7 +96,7 @@ func TestStmtAlterDatabase_Query(t *testing.T) { db := _openDb(t, testName) _, err := db.Query("ALTER DATABASE dbtemp WITH ru=400") if err != ErrQueryNotSupported { - t.Fatalf("%s failed: expected ErrOperationNotSupported, but received %#v", testName, err) + t.Fatalf("%s failed: expected ErrQueryNotSupported, but received %#v", testName, err) } } @@ -181,7 +181,7 @@ func TestStmtDropDatabase_Query(t *testing.T) { db := _openDb(t, testName) _, err := db.Query("DROP DATABASE dbtemp") if err != ErrQueryNotSupported { - t.Fatalf("%s failed: expected ErrOperationNotSupported, but received %#v", testName, err) + t.Fatalf("%s failed: expected ErrQueryNotSupported, but received %#v", testName, err) } } @@ -255,9 +255,9 @@ func TestStmtDropDatabase_Exec(t *testing.T) { func TestStmtListDatabases_Exec(t *testing.T) { testName := "TestStmtListDatabases_Exec" db := _openDb(t, testName) - _, err := db.Query("LIST DATABASES") + _, err := db.Exec("LIST DATABASES") if err != ErrExecNotSupported { - t.Fatalf("%s failed: expected ErrOperationNotSupported, but received %#v", testName, err) + t.Fatalf("%s failed: expected ErrExecNotSupported, but received %#v", testName, err) } } From 6369a6c0da5e64f637df00eeb1cc9b462daba50c Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Mon, 5 Jun 2023 18:46:24 +1000 Subject: [PATCH 06/13] refactor StmtInsert and update tests --- .github/workflows/gocosmos.yaml | 26 +++ gocosmos_test.go | 92 ---------- stmt_document.go | 48 +----- stmt_document_test.go | 294 ++++++++++++++++++++++++++++++++ 4 files changed, 329 insertions(+), 131 deletions(-) create mode 100644 stmt_document_test.go diff --git a/.github/workflows/gocosmos.yaml b/.github/workflows/gocosmos.yaml index 18581d4..7545a87 100644 --- a/.github/workflows/gocosmos.yaml +++ b/.github/workflows/gocosmos.yaml @@ -78,6 +78,32 @@ jobs: flags: driver_collection name: driver_collection + testDriverStmtDocumentNonQuery: + name: Test driver document non-query statements + runs-on: windows-latest + steps: + - name: Set up Go env + uses: actions/setup-go@v2 + with: + go-version: ^1.13 + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + - name: Test + run: | + choco install azure-cosmosdb-emulator + & "C:\Program Files\Azure Cosmos DB Emulator\Microsoft.Azure.Cosmos.Emulator.exe" /DisableRateLimiting /NoUI /NoExplorer + Start-Sleep -s 60 + try { Invoke-RestMethod -Method GET https://127.0.0.1:8081/ } catch {} + netstat -nt + $env:COSMOSDB_DRIVER='gocosmos' + $env:COSMOSDB_URL='AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==' + go test -cover -coverprofile="coverage_driver_document_non_query.txt" -v -count 1 -p 1 -run "TestStmt(Insert|Upsert|Update|Delete)_(Exec|Query)" . + - name: Codecov + uses: codecov/codecov-action@v3 + with: + flags: driver_document_non_query + name: driver_document_non_query + testDriverSelect: name: Test driver SELECT query runs-on: windows-latest diff --git a/gocosmos_test.go b/gocosmos_test.go index ad222b4..613c52b 100644 --- a/gocosmos_test.go +++ b/gocosmos_test.go @@ -128,98 +128,6 @@ func TestDriver_Close(t *testing.T) { /*----------------------------------------------------------------------*/ -func Test_Query_Insert(t *testing.T) { - name := "Test_Query_Insert" - db := _openDb(t, name) - _, err := db.Query("INSERT INTO db.table (a,b,c) VALUES (1,2,3)", nil) - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -func Test_Exec_Insert(t *testing.T) { - name := "Test_Exec_Insert" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_exists") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - if _, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/username WITH uk=/email"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - if result, err := db.Exec(`INSERT INTO dbtemp.tbltemp (id, username, email, grade, actived) VALUES ("\"1\"", "\"user\"", "\"user@domain1.com\"", 7, true)`, "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if regexp.MustCompile(`(?i){\s*LastInsertId\s*:\s*[^}]+?\s*}`).FindString(err.Error()) == "" { - t.Fatalf("%s failed: can not catch LastInsertId / %s", name, err) - } else if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if _, err := db.Exec(`INSERT INTO dbtemp.tbltemp (id,username,email,grade,actived) VALUES ("\"1\"", "\"user\"", "\"user@domain2.com\"", 8, false)`, "user"); err != ErrConflict { - // duplicated id (in logical partition scope) - t.Fatalf("%s failed: expected ErrConflict but received %#v", name, err) - } - - if _, err := db.Exec(`INSERT INTO dbtemp.tbltemp (id,username,email,grade,actived) VALUES ("\"2\"", "\"user\"", "\"user@domain1.com\"", 9, false)`, "user"); err != ErrConflict { - // duplicated unique index (in logical partition scope) - t.Fatalf("%s failed: expected ErrConflict but received %#v", name, err) - } - - if _, err := db.Exec(`INSERT INTO db_not_exists.table (id,username,email) VALUES ("\"x\"", "\"y\"", "\"x\"")`, "y"); err != ErrNotFound { - // database/table not found - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } - if _, err := db.Exec(`INSERT INTO dbtemp.tbl_not_found (id,username,email) VALUES ("\"x\"", "\"y\"", "\"x\"")`, "y"); err != ErrNotFound { - // database/table not found - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } -} - -func Test_Exec_InsertPlaceholder(t *testing.T) { - name := "Test_Exec_InsertPlaceholder" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_exists") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - if _, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/username WITH uk=/email"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - if result, err := db.Exec(`INSERT INTO dbtemp.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, - "1", "user", "user@domain1.com", 1, true, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if regexp.MustCompile(`(?i){\s*LastInsertId\s*:\s*[^}]+?\s*}`).FindString(err.Error()) == "" { - t.Fatalf("%s failed: can not catch LastInsertId / %s", name, err) - } else if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if _, err := db.Exec(`INSERT INTO dbtemp.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, - "1", "user", "user@domain2.com", 2, false, nil, "user"); err != ErrConflict { - // duplicated id (in logical partition scope) - t.Fatalf("%s failed: expected ErrConflict but received %#v", name, err) - } - - if _, err := db.Exec(`INSERT INTO dbtemp.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, - "2", "user", "user@domain1.com", 3, false, nil, "user"); err != ErrConflict { - // duplicated unique index (in logical partition scope) - t.Fatalf("%s failed: expected ErrConflict but received %#v", name, err) - } - - if _, err := db.Exec(`INSERT INTO dbtemp.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :10)`, - "9", "user", "user@domain.com", 9, false, nil, "user"); err == nil || strings.Index(err.Error(), "invalid value index") < 0 { - t.Fatalf("%s failed: expected 'invalid value index' bur received %#v", name, err) - } -} - func Test_Query_Upsert(t *testing.T) { name := "Test_Query_Upsert" db := _openDb(t, name) diff --git a/stmt_document.go b/stmt_document.go index f05d393..6126b32 100644 --- a/stmt_document.go +++ b/stmt_document.go @@ -130,10 +130,9 @@ func (s *StmtInsert) validate() error { return nil } -// Exec implements driver.Stmt.Exec. -// Upon successful call, this function returns (*ResultInsert, nil). +// Exec implements driver.Stmt/Exec. // -// Note: this function expects the last argument is partition key value. +// Note: this function expects the _last_ argument is _partition_ key value. func (s *StmtInsert) Exec(args []driver.Value) (driver.Result, error) { spec := DocumentSpec{ DbName: s.dbName, @@ -155,47 +154,18 @@ func (s *StmtInsert) Exec(args []driver.Value) (driver.Result, error) { } } restResult := s.conn.restClient.CreateDocument(spec) - result := &ResultInsert{Successful: restResult.Error() == nil} + rid := "" if restResult.DocInfo != nil { - result.InsertId, _ = restResult.DocInfo["_rid"].(string) + rid, _ = restResult.DocInfo["_rid"].(string) } - err := restResult.Error() - switch restResult.StatusCode { - case 403: - err = ErrForbidden - case 404: - err = ErrNotFound - case 409: - err = ErrConflict - } - return result, err + result := buildResultNoResultSet(&restResult.RestReponse, true, rid, 0) + return result, result.err } -// Query implements driver.Stmt.Query. +// Query implements driver.Stmt/Query. // This function is not implemented, use Exec instead. -func (s *StmtInsert) Query(args []driver.Value) (driver.Rows, error) { - return nil, errors.New("this operation is not supported, please use exec") -} - -// ResultInsert captures the result from INSERT operation. -type ResultInsert struct { - // Successful flags if the operation was successful or not. - Successful bool - // InsertId holds the "_rid" if the operation was successful. - InsertId string -} - -// LastInsertId implements driver.Result.LastInsertId. -func (r *ResultInsert) LastInsertId() (int64, error) { - return 0, fmt.Errorf("this operation is not supported. {LastInsertId:%s}", r.InsertId) -} - -// RowsAffected implements driver.Result.RowsAffected. -func (r *ResultInsert) RowsAffected() (int64, error) { - if r.Successful { - return 1, nil - } - return 0, nil +func (s *StmtInsert) Query(_ []driver.Value) (driver.Rows, error) { + return nil, ErrQueryNotSupported } /*----------------------------------------------------------------------*/ diff --git a/stmt_document_test.go b/stmt_document_test.go new file mode 100644 index 0000000..aa0e622 --- /dev/null +++ b/stmt_document_test.go @@ -0,0 +1,294 @@ +package gocosmos + +import ( + "encoding/json" + "fmt" + "strings" + "testing" + "time" +) + +func TestStmtInsert_Query(t *testing.T) { + testName := "TestStmtInsert_Query" + db := _openDb(t, testName) + _, err := db.Query("INSERT INTO db.table (a,b,c) VALUES (1,2,3)", nil) + if err != ErrQueryNotSupported { + t.Fatalf("%s failed: expected ErrQueryNotSupported, but received %#v", testName, err) + } +} + +func TestStmtInsert_Exec(t *testing.T) { + testName := "TestStmtInsert_Exec" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + args []interface{} + mustConflict bool + mustNotFound bool + mustError string + affectedRows int64 + }{ + { + name: "insert_new", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), + }, + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, username, email, grade, actived) VALUES ("\"1\"", "\"user\"", "\"user@domain1.com\"", 7, true)`, dbname), + args: []interface{}{"user"}, + affectedRows: 1, + }, + { + name: "insert_conflict_pk", + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,actived) VALUES ("\"1\"", "\"user\"", "\"user@domain2.com\"", 8, false)`, dbname), + args: []interface{}{"user"}, + mustConflict: true, + }, + { + name: "insert_conflict_uk", + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,actived) VALUES ("\"2\"", "\"user\"", "\"user@domain1.com\"", 9, false)`, dbname), + args: []interface{}{"user"}, + mustConflict: true, + }, + { + name: "table_not_exists", + sql: fmt.Sprintf(`INSERT INTO %s.tbl_not_found (id,username,email) VALUES ("\"x\"", "\"y\"", "\"x\"")`, dbname), + args: []interface{}{"y"}, + mustNotFound: true, + }, + { + name: "db_not_exists", + sql: `INSERT INTO db_not_exists.table (id,username,email) VALUES ("\"x\"", "\"y\"", "\"x\"")`, + args: []interface{}{"y"}, + mustNotFound: true, + }, + { + name: "insert_new_placeholders", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), + }, + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, dbname), + args: []interface{}{"1", "user", "user@domain1.com", 1, true, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "user"}, + affectedRows: 1, + }, + { + name: "insert_conflict_pk_placeholders", + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, dbname), + args: []interface{}{"1", "user", "user@domain2.com", 2, false, nil, "user"}, + mustConflict: true, + }, + { + name: "insert_conflict_uk_placeholders", + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, dbname), + args: []interface{}{"2", "user", "user@domain1.com", 3, false, nil, "user"}, + mustConflict: true, + }, + { + name: "table_not_exists_placeholders", + sql: fmt.Sprintf(`INSERT INTO %s.tbl_not_found (id,username,email) VALUES (:1, :2, :3)`, dbname), + args: []interface{}{"x", "y", "x", "y"}, + mustNotFound: true, + }, + { + name: "db_not_exists_placeholders", + sql: `INSERT INTO db_not_exists.table (id,username,email) VALUES (@1, @2, @3)`, + args: []interface{}{"x", "y", "x", "y"}, + mustNotFound: true, + }, + { + name: "error_invalid_value_index", + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :10)`, dbname), + args: []interface{}{"2", "user", "user@domain1.com", 3, false, nil, "user"}, + mustError: "invalid value index", + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql, testCase.args...) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if testCase.mustError != "" { + if err == nil || strings.Index(err.Error(), testCase.mustError) < 0 { + t.Fatalf("%s failed: expected '%s' bur received %#v", testCase.name, testCase.mustError, err) + } + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err == nil { + t.Fatalf("%s failed: expected LastInsertId but received nil", testName+"/"+testCase.name) + } + lastInsertId := make(map[string]interface{}) + err = json.Unmarshal([]byte(err.Error()), &lastInsertId) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/LastInsertId", err) + } + if len(lastInsertId) != 1 { + t.Fatalf("%s failed - LastInsertId: %#v", testName+"/"+testCase.name+"/LastInsertId", lastInsertId) + } + }) + } +} + +func TestStmtInsert_Exec_DefaultDb(t *testing.T) { + testName := "TestStmtInsert_Exec_DefaultDb" + dbname := "dbdefault" + db := _openDefaultDb(t, testName, dbname) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + args []interface{} + mustConflict bool + mustNotFound bool + mustError string + affectedRows int64 + }{ + { + name: "insert_new", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), + }, + sql: `INSERT INTO tbltemp (id, username, email, grade, actived) VALUES ("\"1\"", "\"user\"", "\"user@domain1.com\"", 7, true)`, + args: []interface{}{"user"}, + affectedRows: 1, + }, + { + name: "insert_conflict_pk", + sql: `INSERT INTO tbltemp (id,username,email,grade,actived) VALUES ("\"1\"", "\"user\"", "\"user@domain2.com\"", 8, false)`, + args: []interface{}{"user"}, + mustConflict: true, + }, + { + name: "insert_conflict_uk", + sql: `INSERT INTO tbltemp (id,username,email,grade,actived) VALUES ("\"2\"", "\"user\"", "\"user@domain1.com\"", 9, false)`, + args: []interface{}{"user"}, + mustConflict: true, + }, + { + name: "table_not_exists", + sql: `INSERT INTO tbl_not_found (id,username,email) VALUES ("\"x\"", "\"y\"", "\"x\"")`, + args: []interface{}{"y"}, + mustNotFound: true, + }, + { + name: "insert_new_placeholders", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), + }, + sql: `INSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, + args: []interface{}{"1", "user", "user@domain1.com", 1, true, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "user"}, + affectedRows: 1, + }, + { + name: "insert_conflict_pk_placeholders", + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, dbname), + args: []interface{}{"1", "user", "user@domain2.com", 2, false, nil, "user"}, + mustConflict: true, + }, + { + name: "insert_conflict_uk_placeholders", + sql: `INSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, + args: []interface{}{"2", "user", "user@domain1.com", 3, false, nil, "user"}, + mustConflict: true, + }, + { + name: "table_not_exists_placeholders", + sql: `INSERT INTO tbl_not_found (id,username,email) VALUES (:1, :2, :3)`, + args: []interface{}{"x", "y", "x", "y"}, + mustNotFound: true, + }, + { + name: "error_invalid_value_index", + sql: `INSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :10)`, + args: []interface{}{"2", "user", "user@domain1.com", 3, false, nil, "user"}, + mustError: "invalid value index", + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql, testCase.args...) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if testCase.mustError != "" { + if err == nil || strings.Index(err.Error(), testCase.mustError) < 0 { + t.Fatalf("%s failed: expected '%s' bur received %#v", testCase.name, testCase.mustError, err) + } + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err == nil { + t.Fatalf("%s failed: expected LastInsertId but received nil", testName+"/"+testCase.name) + } + lastInsertId := make(map[string]interface{}) + err = json.Unmarshal([]byte(err.Error()), &lastInsertId) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/LastInsertId", err) + } + if len(lastInsertId) != 1 { + t.Fatalf("%s failed - LastInsertId: %#v", testName+"/"+testCase.name+"/LastInsertId", lastInsertId) + } + }) + } +} From 5c904807d4c9bf7bbb34882f892234bbe39df963 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Tue, 6 Jun 2023 23:29:24 +1000 Subject: [PATCH 07/13] refactor UPSERT and update tests --- gocosmos_test.go | 120 ---------------- stmt_document_test.go | 311 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 311 insertions(+), 120 deletions(-) diff --git a/gocosmos_test.go b/gocosmos_test.go index 613c52b..655b715 100644 --- a/gocosmos_test.go +++ b/gocosmos_test.go @@ -4,10 +4,8 @@ import ( "context" "database/sql" "os" - "regexp" "strings" "testing" - "time" ) func Test_OpenDatabase(t *testing.T) { @@ -128,124 +126,6 @@ func TestDriver_Close(t *testing.T) { /*----------------------------------------------------------------------*/ -func Test_Query_Upsert(t *testing.T) { - name := "Test_Query_Upsert" - db := _openDb(t, name) - _, err := db.Query("UPSERT INTO db.table (a,b,c) VALUES (1,2,3)", nil) - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -func Test_Exec_Upsert(t *testing.T) { - name := "Test_Exec_Upsert" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_exists") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - if _, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/username WITH uk=/email"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - if result, err := db.Exec(`UPSERT INTO dbtemp.tbltemp (id, username, email, grade, actived) VALUES ("\"1\"", "\"user1\"", "\"user1@domain.com\"", 7, true)`, "user1"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if regexp.MustCompile(`(?i){\s*LastInsertId\s*:\s*[^}]+?\s*}`).FindString(err.Error()) == "" { - t.Fatalf("%s failed: can not catch LastInsertId / %s", name, err) - } else if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - if result, err := db.Exec(`UPSERT INTO dbtemp.tbltemp (id, username, email, grade, actived) VALUES ("\"2\"", "\"user2\"", "\"user2@domain.com\"", 7, true)`, "user2"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if regexp.MustCompile(`(?i){\s*LastInsertId\s*:\s*[^}]+?\s*}`).FindString(err.Error()) == "" { - t.Fatalf("%s failed: can not catch LastInsertId / %s", name, err) - } else if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if result, err := db.Exec(`UPSERT INTO dbtemp.tbltemp (id,username,email,grade,actived) VALUES ("\"1\"", "\"user1\"", "\"user1@domain1.com\"", 8, false)`, "user1"); err != nil { - // duplicated id (in logical partition scope): existing document should be overwritten - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if regexp.MustCompile(`(?i){\s*LastInsertId\s*:\s*[^}]+?\s*}`).FindString(err.Error()) == "" { - t.Fatalf("%s failed: can not catch LastInsertId / %s", name, err) - } else if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if _, err := db.Exec(`UPSERT INTO dbtemp.tbltemp (id,username,email,grade,actived) VALUES ("\"1\"", "\"user2\"", "\"user2@domain.com\"", 9, true)`, "user2"); err != ErrConflict { - // duplicated unique index (in logical partition scope) - t.Fatalf("%s failed: expected ErrConflict but received %#v", name, err) - } - - if _, err := db.Exec(`UPSERT INTO db_not_exists.table (id,username,email) VALUES ("\"x\"", "\"y\"", "\"x\"")`, "y"); err != ErrNotFound { - // database/table not found - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } - if _, err := db.Exec(`UPSERT INTO dbtemp.tbl_not_found (id,username,email) VALUES ("\"x\"", "\"y\"", "\"x\"")`, "y"); err != ErrNotFound { - // database/table not found - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } -} - -func Test_Exec_UpsertPlaceholder(t *testing.T) { - name := "Test_Exec_UpsertPlaceholder" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_exists") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - if _, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/username WITH uk=/email"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - if result, err := db.Exec(`UPSERT INTO dbtemp.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, - "1", "user1", "user1@domain.com", 1, true, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "user1"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if regexp.MustCompile(`(?i){\s*LastInsertId\s*:\s*[^}]+?\s*}`).FindString(err.Error()) == "" { - t.Fatalf("%s failed: can not catch LastInsertId / %s", name, err) - } else if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - if result, err := db.Exec(`UPSERT INTO dbtemp.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, - "2", "user2", "user2@domain.com", 2, false, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "user2"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if regexp.MustCompile(`(?i){\s*LastInsertId\s*:\s*[^}]+?\s*}`).FindString(err.Error()) == "" { - t.Fatalf("%s failed: can not catch LastInsertId / %s", name, err) - } else if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if result, err := db.Exec(`UPSERT INTO dbtemp.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, - "1", "user1", "user2@domain.com", 2, false, nil, "user1"); err != nil { - // duplicated id (in logical partition scope): existing document should be overwritten - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if regexp.MustCompile(`(?i){\s*LastInsertId\s*:\s*[^}]+?\s*}`).FindString(err.Error()) == "" { - t.Fatalf("%s failed: can not catch LastInsertId / %s", name, err) - } else if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if _, err := db.Exec(`UPSERT INTO dbtemp.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, - "2", "user1", "user2@domain.com", 3, false, nil, "user1"); err != ErrConflict { - // duplicated unique index (in logical partition scope) - t.Fatalf("%s failed: expected ErrConflict but received %#v", name, err) - } -} - func Test_Query_Delete(t *testing.T) { name := "Test_Query_Delete" db := _openDb(t, name) diff --git a/stmt_document_test.go b/stmt_document_test.go index aa0e622..538abb8 100644 --- a/stmt_document_test.go +++ b/stmt_document_test.go @@ -292,3 +292,314 @@ func TestStmtInsert_Exec_DefaultDb(t *testing.T) { }) } } + +func TestStmtUpsert_Query(t *testing.T) { + testName := "TestStmtUpsert_Query" + db := _openDb(t, testName) + _, err := db.Query("UPSERT INTO db.table (a,b,c) VALUES (1,2,3)", nil) + if err != ErrQueryNotSupported { + t.Fatalf("%s failed: expected ErrQueryNotSupported, but received %#v", testName, err) + } +} + +func TestStmtUpsert_Exec(t *testing.T) { + testName := "TestStmtUpsert_Exec" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + args []interface{} + mustConflict bool + mustNotFound bool + mustError string + affectedRows int64 + }{ + { + name: "upsert_new", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH UK=/email", dbname), + }, + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived) VALUES ("\"1\"", "\"user1\"", "\"user1@domain.com\"", 7, true)`, dbname), + args: []interface{}{"user1"}, + affectedRows: 1, + }, + { + name: "upsert_another", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived) VALUES ("\"2\"", "\"user2\"", "\"user2@domain.com\"", 7, true)`, dbname), + args: []interface{}{"user2"}, + affectedRows: 1, + }, + { + name: "upsert_duplicated_id_placeholders", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id,username,email,grade,actived) VALUES ("\"1\"", "\"user1\"", "\"user3@domain1.com\"", 8, false)`, dbname), + args: []interface{}{"user1"}, + affectedRows: 1, + }, + { + name: "upsert_conflict_uk", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id,username,email,grade,actived) VALUES ("\"3\"", "\"user2\"", "\"user2@domain.com\"", 9, true)`, dbname), + args: []interface{}{"user2"}, + mustConflict: true, + }, + { + name: "table_not_exists", + sql: fmt.Sprintf(`UPSERT INTO %s.tbl_not_found (id,username,email) VALUES ("\"x\"", "\"y\"", "\"x\"")`, dbname), + args: []interface{}{"y"}, + mustNotFound: true, + }, + { + name: "db_not_exists", + sql: `UPSERT INTO db_not_exists.table (id,username,email) VALUES ("\"x\"", "\"y\"", "\"x\"")`, + args: []interface{}{"y"}, + mustNotFound: true, + }, + { + name: "upsert_new_placeholders", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), + }, + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, dbname), + args: []interface{}{"1", "user1", "user1@domain.com", 1, true, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "user1"}, + affectedRows: 1, + }, + { + name: "upsert_another_placeholders", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, dbname), + args: []interface{}{"2", "user2", "user2@domain.com", 2, false, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "user2"}, + affectedRows: 1, + }, + { + name: "upsert_duplicated_id_placeholders", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, dbname), + args: []interface{}{"1", "user1", "user2@domain.com", 2, false, nil, "user1"}, + affectedRows: 1, + }, + { + name: "upsert_conflict_uk_placeholders", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, dbname), + args: []interface{}{"2", "user1", "user2@domain.com", 3, false, nil, "user1"}, + mustConflict: true, + }, + { + name: "table_not_exists_placeholders", + sql: fmt.Sprintf(`UPSERT INTO %s.tbl_not_found (id,username,email) VALUES (:1, :2, :3)`, dbname), + args: []interface{}{"x", "y", "x", "y"}, + mustNotFound: true, + }, + { + name: "db_not_exists_placeholders", + sql: `UPSERT INTO db_not_exists.table (id,username,email) VALUES (@1, @2, @3)`, + args: []interface{}{"x", "y", "x", "y"}, + mustNotFound: true, + }, + { + name: "error_invalid_value_index", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :10)`, dbname), + args: []interface{}{"2", "user", "user@domain1.com", 3, false, nil, "user"}, + mustError: "invalid value index", + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql, testCase.args...) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if testCase.mustError != "" { + if err == nil || strings.Index(err.Error(), testCase.mustError) < 0 { + t.Fatalf("%s failed: expected '%s' bur received %#v", testCase.name, testCase.mustError, err) + } + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err == nil { + t.Fatalf("%s failed: expected LastInsertId but received nil", testName+"/"+testCase.name) + } + lastInsertId := make(map[string]interface{}) + err = json.Unmarshal([]byte(err.Error()), &lastInsertId) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/LastInsertId", err) + } + if len(lastInsertId) != 1 { + t.Fatalf("%s failed - LastInsertId: %#v", testName+"/"+testCase.name+"/LastInsertId", lastInsertId) + } + }) + } +} + +func TestStmtUpsert_Exec_DefaultDb(t *testing.T) { + testName := "TestStmtUpsert_Exec_DefaultDb" + dbname := "dbdefault" + db := _openDefaultDb(t, testName, dbname) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + args []interface{} + mustConflict bool + mustNotFound bool + mustError string + affectedRows int64 + }{ + { + name: "upsert_new", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH UK=/email", dbname), + }, + sql: `UPSERT INTO tbltemp (id, username, email, grade, actived) VALUES ("\"1\"", "\"user1\"", "\"user1@domain.com\"", 7, true)`, + args: []interface{}{"user1"}, + affectedRows: 1, + }, + { + name: "upsert_another", + sql: `UPSERT INTO tbltemp (id, username, email, grade, actived) VALUES ("\"2\"", "\"user2\"", "\"user2@domain.com\"", 7, true)`, + args: []interface{}{"user2"}, + affectedRows: 1, + }, + { + name: "upsert_duplicated_id_placeholders", + sql: `UPSERT INTO tbltemp (id,username,email,grade,actived) VALUES ("\"1\"", "\"user1\"", "\"user3@domain1.com\"", 8, false)`, + args: []interface{}{"user1"}, + affectedRows: 1, + }, + { + name: "upsert_conflict_uk", + sql: `UPSERT INTO tbltemp (id,username,email,grade,actived) VALUES ("\"3\"", "\"user2\"", "\"user2@domain.com\"", 9, true)`, + args: []interface{}{"user2"}, + mustConflict: true, + }, + { + name: "table_not_exists", + sql: `UPSERT INTO tbl_not_found (id,username,email) VALUES ("\"x\"", "\"y\"", "\"x\"")`, + args: []interface{}{"y"}, + mustNotFound: true, + }, + { + name: "upsert_new_placeholders", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), + }, + sql: `UPSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, + args: []interface{}{"1", "user1", "user1@domain.com", 1, true, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "user1"}, + affectedRows: 1, + }, + { + name: "upsert_another_placeholders", + sql: `UPSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, + args: []interface{}{"2", "user2", "user2@domain.com", 2, false, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "user2"}, + affectedRows: 1, + }, + { + name: "upsert_duplicated_id_placeholders", + sql: `UPSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, + args: []interface{}{"1", "user1", "user2@domain.com", 2, false, nil, "user1"}, + affectedRows: 1, + }, + { + name: "upsert_conflict_uk_placeholders", + sql: `UPSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, + args: []interface{}{"2", "user1", "user2@domain.com", 3, false, nil, "user1"}, + mustConflict: true, + }, + { + name: "table_not_exists_placeholders", + sql: `UPSERT INTO tbl_not_found (id,username,email) VALUES (:1, :2, :3)`, + args: []interface{}{"x", "y", "x", "y"}, + mustNotFound: true, + }, + { + name: "error_invalid_value_index", + sql: `UPSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :10)`, + args: []interface{}{"2", "user", "user@domain1.com", 3, false, nil, "user"}, + mustError: "invalid value index", + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql, testCase.args...) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if testCase.mustError != "" { + if err == nil || strings.Index(err.Error(), testCase.mustError) < 0 { + t.Fatalf("%s failed: expected '%s' bur received %#v", testCase.name, testCase.mustError, err) + } + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err == nil { + t.Fatalf("%s failed: expected LastInsertId but received nil", testName+"/"+testCase.name) + } + lastInsertId := make(map[string]interface{}) + err = json.Unmarshal([]byte(err.Error()), &lastInsertId) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/LastInsertId", err) + } + if len(lastInsertId) != 1 { + t.Fatalf("%s failed - LastInsertId: %#v", testName+"/"+testCase.name+"/LastInsertId", lastInsertId) + } + }) + } +} From fde18de705721c51eaedf87469e3796abef307e0 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Wed, 7 Jun 2023 01:33:38 +1000 Subject: [PATCH 08/13] refactor DELETE and update tests --- gocosmos_test.go | 86 -------------- stmt_document.go | 52 +++------ stmt_document_test.go | 255 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 268 insertions(+), 125 deletions(-) diff --git a/gocosmos_test.go b/gocosmos_test.go index 655b715..5c781ff 100644 --- a/gocosmos_test.go +++ b/gocosmos_test.go @@ -126,92 +126,6 @@ func TestDriver_Close(t *testing.T) { /*----------------------------------------------------------------------*/ -func Test_Query_Delete(t *testing.T) { - name := "Test_Query_Delete" - db := _openDb(t, name) - _, err := db.Query("DELETE FROM db.table WHERE id=1", nil) - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -func Test_Exec_Delete(t *testing.T) { - name := "Test_Exec_Delete" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_exists") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - if _, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/username WITH uk=/email"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - db.Exec(`INSERT INTO dbtemp.tbltemp (id,username,email) VALUES (:1,@2,$3)`, "1", "user", "user@domain1.com", "user") - db.Exec(`INSERT INTO dbtemp.tbltemp (id,username,email) VALUES (:1,@2,$3)`, "2", "user", "user@domain2.com", "user") - db.Exec(`INSERT INTO dbtemp.tbltemp (id,username,email) VALUES (:1,@2,$3)`, "3", "user", "user@domain3.com", "user") - db.Exec(`INSERT INTO dbtemp.tbltemp (id,username,email) VALUES (:1,@2,$3)`, "4", "user", "user@domain4.com", "user") - db.Exec(`INSERT INTO dbtemp.tbltemp (id,username,email) VALUES (:1,@2,$3)`, "5", "user", "user@domain5.com", "user") - - if dbResult, err := db.Exec(`DELETE FROM dbtemp.tbltemp WHERE id=1`, "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := dbResult.LastInsertId(); id != 0 && err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if numRows, err := dbResult.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if dbResult, err := db.Exec(`DELETE FROM dbtemp.tbltemp WHERE id="2"`, "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := dbResult.LastInsertId(); id != 0 && err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if numRows, err := dbResult.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if dbResult, err := db.Exec(`DELETE FROM dbtemp.tbltemp WHERE id=:1`, "3", "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := dbResult.LastInsertId(); id != 0 && err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if numRows, err := dbResult.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - if dbResult, err := db.Exec(`DELETE FROM dbtemp.tbltemp WHERE id=@1`, "4", "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := dbResult.LastInsertId(); id != 0 && err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if numRows, err := dbResult.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - if dbResult, err := db.Exec(`DELETE FROM dbtemp.tbltemp WHERE id=$1`, "5", "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := dbResult.LastInsertId(); id != 0 && err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if numRows, err := dbResult.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if dbResult, err := db.Exec(`DELETE FROM dbtemp.tbltemp WHERE id=1`, "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := dbResult.LastInsertId(); id != 0 && err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if numRows, err := dbResult.RowsAffected(); numRows != 0 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=0/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if _, err := db.Exec(`DELETE FROM dbtemp.table_not_exists WHERE id=1`, "user"); err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } - - if _, err := db.Exec(`DELETE FROM db_not_exists.table WHERE id=1`, "user"); err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } - - if _, err := db.Exec(`DELETE FROM dbtemp.tbltemp WHERE id=$10`, "1", "user"); err == nil || strings.Index(err.Error(), "invalid value index") < 0 { - t.Fatalf("%s failed: expected 'invalid value index' bur received %#v", name, err) - } -} - func Test_Query_Update(t *testing.T) { name := "Test_Query_Update" db := _openDb(t, name) diff --git a/stmt_document.go b/stmt_document.go index 6126b32..b368f38 100644 --- a/stmt_document.go +++ b/stmt_document.go @@ -199,9 +199,6 @@ func (s *StmtDelete) parse() error { } else if loc := reValPlaceholder.FindStringIndex(s.idStr); loc != nil { if loc[0] == 0 && loc[1] == len(s.idStr) { index, _ := strconv.Atoi(s.idStr[loc[0]+1:]) - // if err != nil || index < 1 { - // return fmt.Errorf("invalid id placeholder literate: %s", s.idStr) - // } s.id = placeholder{index} s.numInput++ } else { @@ -221,10 +218,9 @@ func (s *StmtDelete) validate() error { return nil } -// Exec implements driver.Stmt.Exec. -// This function always return nil driver.Result. +// Exec implements driver.Stmt/Exec. // -// Note: this function expects the last argument is partition key value. +// Note: this function expects the _last_ argument is _partition_ key value. func (s *StmtDelete) Exec(args []driver.Value) (driver.Result, error) { id := s.idStr if s.id != nil { @@ -234,51 +230,29 @@ func (s *StmtDelete) Exec(args []driver.Value) (driver.Result, error) { } id = fmt.Sprintf("%s", args[ph.index-1]) } - restClient := s.conn.restClient.DeleteDocument(DocReq{DbName: s.dbName, CollName: s.collName, DocId: id, + restResult := s.conn.restClient.DeleteDocument(DocReq{DbName: s.dbName, CollName: s.collName, DocId: id, PartitionKeyValues: []interface{}{args[s.numInput-1]}, // expect the last argument is partition key value }) - err := restClient.Error() - result := &ResultDelete{Successful: err == nil, StatusCode: restClient.StatusCode} - switch restClient.StatusCode { + result := buildResultNoResultSet(&restResult.RestReponse, false, "", 0) + switch restResult.StatusCode { case 403: - err = ErrForbidden + result.err = ErrForbidden case 404: // consider "document not found" as successful operation // but database/collection not found is not! - if strings.Index(fmt.Sprintf("%s", err), "ResourceType: Document") >= 0 { - err = nil + if strings.Index(fmt.Sprintf("%s", restResult.Error()), "ResourceType: Document") >= 0 { + result.err = nil } else { - err = ErrNotFound + result.err = ErrNotFound } } - return result, err + return result, result.err } -// Query implements driver.Stmt.Query. +// Query implements driver.Stmt/Query. // This function is not implemented, use Exec instead. -func (s *StmtDelete) Query(args []driver.Value) (driver.Rows, error) { - return nil, errors.New("this operation is not supported, please use exec") -} - -// ResultDelete captures the result from DELETE operation. -type ResultDelete struct { - // Successful flags if the operation was successful or not. - Successful bool - // StatusCode is the HTTP status code returned from CosmosDB. - StatusCode int -} - -// LastInsertId implements driver.Result.LastInsertId. -func (r *ResultDelete) LastInsertId() (int64, error) { - return 0, errors.New("this operation is not supported") -} - -// RowsAffected implements driver.Result.RowsAffected. -func (r *ResultDelete) RowsAffected() (int64, error) { - if r.Successful && r.StatusCode < 400 { - return 1, nil - } - return 0, nil +func (s *StmtDelete) Query(_ []driver.Value) (driver.Rows, error) { + return nil, ErrQueryNotSupported } /*----------------------------------------------------------------------*/ diff --git a/stmt_document_test.go b/stmt_document_test.go index 538abb8..ad18962 100644 --- a/stmt_document_test.go +++ b/stmt_document_test.go @@ -603,3 +603,258 @@ func TestStmtUpsert_Exec_DefaultDb(t *testing.T) { }) } } + +func TestStmtDelete_Query(t *testing.T) { + testName := "TestStmtDelete_Query" + db := _openDb(t, testName) + _, err := db.Query("DELETE FROM db.table WHERE id=1", nil) + if err != ErrQueryNotSupported { + t.Fatalf("%s failed: expected ErrQueryNotSupported, but received %#v", testName, err) + } +} + +func TestStmtDelete_Exec(t *testing.T) { + testName := "TestStmtDelete_Exec" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + initParams [][]interface{} + sql string + args []interface{} + mustConflict bool + mustNotFound bool + mustError string + affectedRows int64 + }{ + { + name: "delete_1", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email) VALUES (:1,:2,:3)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email) VALUES (:1,:2,:3)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email) VALUES (:1,:2,:3)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email) VALUES (:1,:2,:3)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email) VALUES (:1,:2,:3)`, dbname), + }, + initParams: [][]interface{}{nil, nil, nil, nil, {"1", "user", "user@domain1.com", "user"}, {"2", "user", "user@domain2.com", "user"}, + {"3", "user", "user@domain3.com", "user"}, {"4", "user", "user@domain4.com", "user"}, {"5", "user", "user@domain5.com", "user"}}, + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=1`, dbname), + args: []interface{}{"user"}, + affectedRows: 1, + }, + { + name: "delete_2", + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id="2"`, dbname), + args: []interface{}{"user"}, + affectedRows: 1, + }, + { + name: "delete_3", + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=:1`, dbname), + args: []interface{}{"3", "user"}, + affectedRows: 1, + }, + { + name: "delete_4", + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=@1`, dbname), + args: []interface{}{"4", "user"}, + affectedRows: 1, + }, + { + name: "delete_5", + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=$1`, dbname), + args: []interface{}{"5", "user"}, + affectedRows: 1, + }, + { + name: "row_not_exists", + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=1`, dbname), + args: []interface{}{"user"}, + affectedRows: 0, + }, + { + name: "table_not_exists", + sql: fmt.Sprintf(`DELETE FROM %s.table_not_exists WHERE id=1`, dbname), + args: []interface{}{"user"}, + mustNotFound: true, + }, + { + name: "db_not_exists", + sql: `DELETE FROM db_not_exists.table WHERE id=1`, + args: []interface{}{"user"}, + mustNotFound: true, + }, + { + name: "error_invalid_value_index", + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=$9`, dbname), + args: []interface{}{"1", "user"}, + mustError: "invalid value index", + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for i, initSql := range testCase.initSqls { + var params []interface{} + if len(testCase.initParams) > i { + params = testCase.initParams[i] + } + _, err := db.Exec(initSql, params...) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql, testCase.args...) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if testCase.mustError != "" { + if err == nil || strings.Index(err.Error(), testCase.mustError) < 0 { + t.Fatalf("%s failed: expected '%s' bur received %#v", testCase.name, testCase.mustError, err) + } + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + }) + } +} + +func TestStmtDelete_Exec_DefaultDb(t *testing.T) { + testName := "TestStmtDelete_Exec_DefaultDb" + dbname := "dbdefault" + db := _openDefaultDb(t, testName, dbname) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + initParams [][]interface{} + sql string + args []interface{} + mustConflict bool + mustNotFound bool + mustError string + affectedRows int64 + }{ + { + name: "delete_1", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email) VALUES (:1,:2,:3)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email) VALUES (:1,:2,:3)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email) VALUES (:1,:2,:3)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email) VALUES (:1,:2,:3)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email) VALUES (:1,:2,:3)`, dbname), + }, + initParams: [][]interface{}{nil, nil, nil, nil, {"1", "user", "user@domain1.com", "user"}, {"2", "user", "user@domain2.com", "user"}, + {"3", "user", "user@domain3.com", "user"}, {"4", "user", "user@domain4.com", "user"}, {"5", "user", "user@domain5.com", "user"}}, + sql: `DELETE FROM tbltemp WHERE id=1`, + args: []interface{}{"user"}, + affectedRows: 1, + }, + { + name: "delete_2", + sql: `DELETE FROM tbltemp WHERE id="2"`, + args: []interface{}{"user"}, + affectedRows: 1, + }, + { + name: "delete_3", + sql: `DELETE FROM tbltemp WHERE id=:1`, + args: []interface{}{"3", "user"}, + affectedRows: 1, + }, + { + name: "delete_4", + sql: `DELETE FROM tbltemp WHERE id=@1`, + args: []interface{}{"4", "user"}, + affectedRows: 1, + }, + { + name: "delete_5", + sql: `DELETE FROM tbltemp WHERE id=$1`, + args: []interface{}{"5", "user"}, + affectedRows: 1, + }, + { + name: "row_not_exists", + sql: `DELETE FROM tbltemp WHERE id=1`, + args: []interface{}{"user"}, + affectedRows: 0, + }, + { + name: "table_not_exists", + sql: `DELETE FROM table_not_exists WHERE id=1`, + args: []interface{}{"user"}, + mustNotFound: true, + }, + { + name: "error_invalid_value_index", + sql: `DELETE FROM tbltemp WHERE id=$9`, + args: []interface{}{"1", "user"}, + mustError: "invalid value index", + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for i, initSql := range testCase.initSqls { + var params []interface{} + if len(testCase.initParams) > i { + params = testCase.initParams[i] + } + _, err := db.Exec(initSql, params...) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql, testCase.args...) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if testCase.mustError != "" { + if err == nil || strings.Index(err.Error(), testCase.mustError) < 0 { + t.Fatalf("%s failed: expected '%s' bur received %#v", testCase.name, testCase.mustError, err) + } + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + }) + } +} From 21a19a33a77433d354348f428087e2eebde825d0 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Wed, 7 Jun 2023 17:49:31 +1000 Subject: [PATCH 09/13] refactor UPDATE and update tests --- driver.go | 6 + gocosmos_test.go | 139 ------------------ stmt.go | 6 + stmt_document.go | 69 ++------- stmt_document_test.go | 319 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 344 insertions(+), 195 deletions(-) diff --git a/driver.go b/driver.go index 942137d..315e02a 100644 --- a/driver.go +++ b/driver.go @@ -66,6 +66,12 @@ var ( // ErrConflict is returned when the executing operation cause conflict (e.g. duplicated id). ErrConflict = errors.New("StatusCode=409 Conflict") + // ErrPreconditionFailure is returned when operation specified an eTag that is different from the version available + // at the server, that is, an optimistic concurrency error. + // + // @Available since v0.3.0 + ErrPreconditionFailure = errors.New("StatusCode=412 Precondition failure") + // ErrOperationNotSupported is returned to indicate that the operation is not supported. // // @Available since v0.3.0 diff --git a/gocosmos_test.go b/gocosmos_test.go index 5c781ff..f6dc0ea 100644 --- a/gocosmos_test.go +++ b/gocosmos_test.go @@ -123,142 +123,3 @@ func TestDriver_Close(t *testing.T) { t.Fatalf("%s failed: %s", name, err) } } - -/*----------------------------------------------------------------------*/ - -func Test_Query_Update(t *testing.T) { - name := "Test_Query_Update" - db := _openDb(t, name) - _, err := db.Query("UPDATE db.table SET a=1 WHERE id=2", nil) - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -func Test_Exec_Update(t *testing.T) { - name := "Test_Exec_Update" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_exists") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - if _, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/username WITH uk=/email"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - if _, err := db.Exec(`INSERT INTO dbtemp.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, - "1", "user", "user@domain.com", 1, true, "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if _, err := db.Exec(`INSERT INTO dbtemp.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, - "2", "user", "user2@domain.com", 1, true, "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - if result, err := db.Exec(`UPDATE dbtemp.tbltemp SET grade=2.0,active=false,data="\"a string 'with' \\\"quote\\\"\"" WHERE id=1`, "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if result, err := db.Exec(`UPDATE dbtemp.tbltemp SET username="\"user1\"" WHERE id=1`, "user1"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if numRows, err := result.RowsAffected(); numRows != 0 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=0/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if _, err := db.Exec(`UPDATE dbtemp.tbltemp SET email="\"user2@domain.com\"" WHERE id=1`, "user"); err != ErrConflict { - t.Fatalf("%s failed: %s", name, err) - } - - if _, err := db.Exec(`UPDATE dbtemp.tbl_not_found SET email="\"user2@domain.com\"" WHERE id=1`, "user"); err != ErrNotFound { - t.Fatalf("%s failed: %s", name, err) - } - - if _, err := db.Exec(`UPDATE db_not_exists.tbltemp SET email="\"user2@domain.com\"" WHERE id=1`, "user"); err != ErrNotFound { - t.Fatalf("%s failed: %s", name, err) - } - - // can not change document id - if result, err := db.Exec(`UPDATE dbtemp.tbltemp SET id="\"0\"" WHERE id=1`, "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if numRows, err := result.RowsAffected(); numRows != 0 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=0/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } -} - -func Test_Exec_UpdatePlaceholder(t *testing.T) { - name := "Test_Exec_UpdatePlaceholder" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_exists") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - if _, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/username WITH uk=/email"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - if _, err := db.Exec(`INSERT INTO dbtemp.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, - "1", "user", "user@domain.com", 1, true, "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if _, err := db.Exec(`INSERT INTO dbtemp.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, - "2", "user", "user2@domain.com", 1, true, "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - if result, err := db.Exec(`UPDATE dbtemp.tbltemp SET grade=@1,active=$2,data=:3 WHERE id=$4`, - 2.0, false, `a string 'with' "quote"`, "1", "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if numRows, err := result.RowsAffected(); numRows != 1 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=1/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if result, err := db.Exec(`UPDATE dbtemp.tbltemp SET username=:1 WHERE id=$2`, - "user1", "1", "user1"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if numRows, err := result.RowsAffected(); numRows != 0 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=0/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if _, err := db.Exec(`UPDATE dbtemp.tbltemp SET email=:1 WHERE id=@2`, "user2@domain.com", "1", "user"); err != ErrConflict { - t.Fatalf("%s failed: %s", name, err) - } - - if _, err := db.Exec(`UPDATE dbtemp.tbl_not_found SET email=@1 WHERE id=1`, "user2@domain.com", "user"); err != ErrNotFound { - t.Fatalf("%s failed: %s", name, err) - } - - if _, err := db.Exec(`UPDATE db_not_exists.tbltemp SET email=$1 WHERE id=1`, "user2@domain.com", "user"); err != ErrNotFound { - t.Fatalf("%s failed: %s", name, err) - } - - // can not change document id - if result, err := db.Exec(`UPDATE dbtemp.tbltemp SET id=:1 WHERE id=1`, "0", "user"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else if id, err := result.LastInsertId(); id != 0 || err == nil { - t.Fatalf("%s failed: expected LastInsertId=0/err!=nil but received LastInsertId=%d/err=%s", name, id, err) - } else if numRows, err := result.RowsAffected(); numRows != 0 || err != nil { - t.Fatalf("%s failed: expected RowsAffected=0/err=nil but received RowsAffected=%d/err=%s", name, numRows, err) - } - - if _, err := db.Exec(`UPDATE dbtemp.tbltemp SET grade=10 WHERE id=$10`, "1", "user"); err == nil || strings.Index(err.Error(), "invalid value index") < 0 { - t.Fatalf("%s failed: expected 'invalid value index' but received '%s'", name, err) - } - - if _, err := db.Exec(`UPDATE dbtemp.tbltemp SET grade=$10 WHERE id=1`, "1", "user"); err == nil || strings.Index(err.Error(), "invalid value index") < 0 { - t.Fatalf("%s failed: expected 'invalid value index' but received '%s'", name, err) - } -} diff --git a/stmt.go b/stmt.go index cf49ff3..ef5ac66 100644 --- a/stmt.go +++ b/stmt.go @@ -270,6 +270,12 @@ func buildResultNoResultSet(restResponse *RestReponse, supportLastInsertId bool, } else { result.err = ErrConflict } + case 412: + if ignoreErrorCode == 412 { + result.err = nil + } else { + result.err = ErrPreconditionFailure + } } return result } diff --git a/stmt_document.go b/stmt_document.go index b368f38..cb219cb 100644 --- a/stmt_document.go +++ b/stmt_document.go @@ -235,15 +235,11 @@ func (s *StmtDelete) Exec(args []driver.Value) (driver.Result, error) { }) result := buildResultNoResultSet(&restResult.RestReponse, false, "", 0) switch restResult.StatusCode { - case 403: - result.err = ErrForbidden case 404: // consider "document not found" as successful operation // but database/collection not found is not! if strings.Index(fmt.Sprintf("%s", restResult.Error()), "ResourceType: Document") >= 0 { result.err = nil - } else { - result.err = ErrNotFound } } return result, result.err @@ -466,16 +462,9 @@ func (s *StmtUpdate) _parseId() error { if hasPrefix && hasSuffix { s.idStr = strings.TrimSpace(s.idStr[1 : len(s.idStr)-1]) } else if loc := reValPlaceholder.FindStringIndex(s.idStr); loc != nil { - // if loc[0] == 0 && loc[1] == len(s.idStr) { index, _ := strconv.Atoi(s.idStr[loc[0]+1:]) - // if err != nil || index < 1 { - // return fmt.Errorf("invalid id placeholder literate: %s", s.idStr) - // } s.id = placeholder{index} s.numInput++ - // } else { - // return fmt.Errorf("invalid id literate: %s", s.idStr) - // } } return nil } @@ -545,16 +534,12 @@ func (s *StmtUpdate) validate() error { if len(s.fields) == 0 { return errors.New("invalid query: SET clause is empty") } - // if len(s.fields) != len(s.values) { - // return fmt.Errorf("number of field (%d) does not match number of input value (%d)", len(s.fields), len(s.values)) - // } return nil } -// Exec implements driver.Stmt.Exec. -// Upon successful call, this function returns (*ResultUpdate, nil). +// Exec implements driver.Stmt/Exec. // -// Note: this function expects the last argument is partition key value. +// Note: this function expects the _last_ argument is _partition_ key value. func (s *StmtUpdate) Exec(args []driver.Value) (driver.Result, error) { // firstly, fetch the document id := s.idStr @@ -568,15 +553,15 @@ func (s *StmtUpdate) Exec(args []driver.Value) (driver.Result, error) { docReq := DocReq{DbName: s.dbName, CollName: s.collName, DocId: id, PartitionKeyValues: []interface{}{args[len(args)-1]}} getDocResult := s.conn.restClient.GetDocument(docReq) if err := getDocResult.Error(); err != nil { + result := buildResultNoResultSet(&getDocResult.RestReponse, false, "", 0) if getDocResult.StatusCode == 404 { // consider "document not found" as successful operation // but database/collection not found is not! if strings.Index(fmt.Sprintf("%s", err), "ResourceType: Document") >= 0 { - return &ResultUpdate{Successful: false}, nil + result.err = nil } - return nil, ErrNotFound } - return nil, getDocResult.Error() + return result, result.err } etag := getDocResult.DocInfo.Etag() spec := DocumentSpec{DbName: s.dbName, CollName: s.collName, PartitionKeyValues: []interface{}{args[len(args)-1]}, DocumentData: getDocResult.DocInfo.RemoveSystemAttrs()} @@ -593,48 +578,20 @@ func (s *StmtUpdate) Exec(args []driver.Value) (driver.Result, error) { } } replaceDocResult := s.conn.restClient.ReplaceDocument(etag, spec) - result := &ResultUpdate{Successful: replaceDocResult.Error() == nil} - err := replaceDocResult.Error() + result := buildResultNoResultSet(&replaceDocResult.RestReponse, false, "", 412) switch replaceDocResult.StatusCode { - case 403: - err = ErrForbidden - case 404: // race case, but possible + case 404: // rare case, but possible! // consider "document not found" as successful operation // but database/collection not found is not! - if strings.Index(fmt.Sprintf("%s", err), "ResourceType: Document") >= 0 { - err = nil - } else { - err = ErrNotFound + if strings.Index(fmt.Sprintf("%s", replaceDocResult.Error()), "ResourceType: Document") >= 0 { + result.err = nil } - case 409: - err = ErrConflict - case 412: - err = nil } - return result, err + return result, result.err } -// Query implements driver.Stmt.Query. +// Query implements driver.Stmt/Query. // This function is not implemented, use Exec instead. -func (s *StmtUpdate) Query(args []driver.Value) (driver.Rows, error) { - return nil, errors.New("this operation is not supported, please use exec") -} - -// ResultUpdate captures the result from UPDATE operation. -type ResultUpdate struct { - // Successful flags if the operation was successful or not. - Successful bool -} - -// LastInsertId implements driver.Result.LastInsertId. -func (r *ResultUpdate) LastInsertId() (int64, error) { - return 0, errors.New("this operation is not supported") -} - -// RowsAffected implements driver.Result.RowsAffected. -func (r *ResultUpdate) RowsAffected() (int64, error) { - if r.Successful { - return 1, nil - } - return 0, nil +func (s *StmtUpdate) Query(_ []driver.Value) (driver.Rows, error) { + return nil, ErrQueryNotSupported } diff --git a/stmt_document_test.go b/stmt_document_test.go index ad18962..36091ff 100644 --- a/stmt_document_test.go +++ b/stmt_document_test.go @@ -858,3 +858,322 @@ func TestStmtDelete_Exec_DefaultDb(t *testing.T) { }) } } + +func TestStmtUpdate_Query(t *testing.T) { + testName := "TestStmtUpdate_Query" + db := _openDb(t, testName) + _, err := db.Query("UPDATE db.table SET a=1 WHERE id=2", nil) + if err != ErrQueryNotSupported { + t.Fatalf("%s failed: expected ErrQueryNotSupported, but received %#v", testName, err) + } +} + +func TestStmtUpdate_Exec(t *testing.T) { + testName := "TestStmtUpdate_Exec" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + initParams [][]interface{} + sql string + args []interface{} + mustConflict bool + mustNotFound bool + mustError string + affectedRows int64 + }{ + { + name: "update_1", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, dbname), + }, + initParams: [][]interface{}{nil, nil, nil, nil, {"1", "user", "user@domain.com", 1, true, "user"}, {"2", "user", "user2@domain.com", 1, true, "user"}}, + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=2.0,active=false,data="\"a string 'with' \\\"quote\\\"\"" WHERE id=1`, dbname), + args: []interface{}{"user"}, + affectedRows: 1, + }, + { + name: "update_pk", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET username="\"user1\"" WHERE id=1`, dbname), + args: []interface{}{"user1"}, + affectedRows: 0, + }, + { + name: "error_uk", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET email="\"user2@domain.com\"" WHERE id=1`, dbname), + args: []interface{}{"user"}, + mustConflict: true, + }, + { + name: "row_not_exists", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=3.4 WHERE id=3`, dbname), + args: []interface{}{"user"}, + affectedRows: 0, + }, + { + name: "row_not_exists_in_partition", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=5.6 WHERE id=2`, dbname), + args: []interface{}{"user2"}, + affectedRows: 0, + }, + { + name: "table_not_exists", + sql: fmt.Sprintf(`UPDATE %s.tbl_not_found SET email="\"user2@domain.com\"" WHERE id=1`, dbname), + args: []interface{}{"user"}, + mustNotFound: true, + }, + { + name: "db_not_exists", + sql: `UPDATE db_not_exists.tbltemp SET email="\"user2@domain.com\"" WHERE id=1`, + args: []interface{}{"user"}, + mustNotFound: true, + }, + { + name: "update_1_placeholders", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, dbname), + }, + initParams: [][]interface{}{nil, nil, nil, nil, {"1", "user", "user@domain.com", 1, true, "user"}, {"2", "user", "user2@domain.com", 1, true, "user"}}, + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=:1,active=@2,data=$3 WHERE id=:4`, dbname), + args: []interface{}{2.0, false, "a string 'with' \"quote\"", "1", "user"}, + affectedRows: 1, + }, + { + name: "update_pk_placeholders", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET username=$1 WHERE id=:2`, dbname), + args: []interface{}{"user1", "1", "user1"}, + affectedRows: 0, + }, + { + name: "error_uk_placeholders", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET email=@1 WHERE id=:2`, dbname), + args: []interface{}{"user2@domain.com", "1", "user"}, + mustConflict: true, + }, + { + name: "row_not_exists_placeholders", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=$1 WHERE id=:2`, dbname), + args: []interface{}{3.4, "3", "user"}, + affectedRows: 0, + }, + { + name: "row_not_exists_in_partition_placeholders", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=@1 WHERE id=:2`, dbname), + args: []interface{}{5.6, "2", "user2"}, + affectedRows: 0, + }, + { + name: "table_not_exists_placeholders", + sql: fmt.Sprintf(`UPDATE %s.tbl_not_found SET email=:1 WHERE id=:2`, dbname), + args: []interface{}{"user2@domain.com", "1", "user"}, + mustNotFound: true, + }, + { + name: "db_not_exists_placeholders", + sql: `UPDATE db_not_exists.tbltemp SET email=:1 WHERE id=:2`, + args: []interface{}{"user2@domain.com", "1", "user"}, + mustNotFound: true, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for i, initSql := range testCase.initSqls { + var params []interface{} + if len(testCase.initParams) > i { + params = testCase.initParams[i] + } + _, err := db.Exec(initSql, params...) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql, testCase.args...) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if testCase.mustError != "" { + if err == nil || strings.Index(err.Error(), testCase.mustError) < 0 { + t.Fatalf("%s failed: expected '%s' bur received %#v", testCase.name, testCase.mustError, err) + } + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + }) + } +} + +func TestStmtUpdate_Exec_DefaultDb(t *testing.T) { + testName := "TestStmtUpdate_Exec_DefaultDb" + dbname := "dbdefault" + db := _openDefaultDb(t, testName, dbname) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + initParams [][]interface{} + sql string + args []interface{} + mustConflict bool + mustNotFound bool + mustError string + affectedRows int64 + }{ + { + name: "update_1", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, dbname), + }, + initParams: [][]interface{}{nil, nil, nil, nil, {"1", "user", "user@domain.com", 1, true, "user"}, {"2", "user", "user2@domain.com", 1, true, "user"}}, + sql: `UPDATE tbltemp SET grade=2.0,active=false,data="\"a string 'with' \\\"quote\\\"\"" WHERE id=1`, + args: []interface{}{"user"}, + affectedRows: 1, + }, + { + name: "update_pk", + sql: `UPDATE tbltemp SET username="\"user1\"" WHERE id=1`, + args: []interface{}{"user1"}, + affectedRows: 0, + }, + { + name: "error_uk", + sql: `UPDATE tbltemp SET email="\"user2@domain.com\"" WHERE id=1`, + args: []interface{}{"user"}, + mustConflict: true, + }, + { + name: "row_not_exists", + sql: `UPDATE tbltemp SET grade=3.4 WHERE id=3`, + args: []interface{}{"user"}, + affectedRows: 0, + }, + { + name: "row_not_exists_in_partition", + sql: `UPDATE tbltemp SET grade=5.6 WHERE id=2`, + args: []interface{}{"user2"}, + affectedRows: 0, + }, + { + name: "table_not_exists", + sql: `UPDATE tbl_not_found SET email="\"user2@domain.com\"" WHERE id=1`, + args: []interface{}{"user"}, + mustNotFound: true, + }, + { + name: "update_1_placeholders", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, dbname), + }, + initParams: [][]interface{}{nil, nil, nil, nil, {"1", "user", "user@domain.com", 1, true, "user"}, {"2", "user", "user2@domain.com", 1, true, "user"}}, + sql: `UPDATE tbltemp SET grade=:1,active=@2,data=$3 WHERE id=:4`, + args: []interface{}{2.0, false, "a string 'with' \"quote\"", "1", "user"}, + affectedRows: 1, + }, + { + name: "update_pk_placeholders", + sql: `UPDATE tbltemp SET username=$1 WHERE id=:2`, + args: []interface{}{"user1", "1", "user1"}, + affectedRows: 0, + }, + { + name: "error_uk_placeholders", + sql: `UPDATE tbltemp SET email=@1 WHERE id=:2`, + args: []interface{}{"user2@domain.com", "1", "user"}, + mustConflict: true, + }, + { + name: "row_not_exists_placeholders", + sql: `UPDATE tbltemp SET grade=$1 WHERE id=:2`, + args: []interface{}{3.4, "3", "user"}, + affectedRows: 0, + }, + { + name: "row_not_exists_in_partition_placeholders", + sql: `UPDATE tbltemp SET grade=@1 WHERE id=:2`, + args: []interface{}{5.6, "2", "user2"}, + affectedRows: 0, + }, + { + name: "table_not_exists_placeholders", + sql: `UPDATE tbl_not_found SET email=:1 WHERE id=:2`, + args: []interface{}{"user2@domain.com", "1", "user"}, + mustNotFound: true, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for i, initSql := range testCase.initSqls { + var params []interface{} + if len(testCase.initParams) > i { + params = testCase.initParams[i] + } + _, err := db.Exec(initSql, params...) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql, testCase.args...) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if testCase.mustError != "" { + if err == nil || strings.Index(err.Error(), testCase.mustError) < 0 { + t.Fatalf("%s failed: expected '%s' bur received %#v", testCase.name, testCase.mustError, err) + } + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + }) + } +} From 0d80cd9eb602aa381f88e0cd420d42f4f8bcf1c5 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Thu, 8 Jun 2023 12:30:24 +1000 Subject: [PATCH 10/13] refactor SELECT and update tests --- .github/workflows/gocosmos.yaml | 26 ++ RELEASE-NOTES.md | 4 + conn.go | 4 +- driver.go | 8 +- driver_test.go | 89 ++++ gocosmos_select_test.go | 774 -------------------------------- gocosmos_test.go | 95 ---- stmt.go | 73 +-- stmt_collection.go | 4 +- stmt_database.go | 4 +- stmt_document.go | 93 +--- stmt_document_select_test.go | 632 ++++++++++++++++++++++++++ stmt_test.go | 27 ++ 13 files changed, 851 insertions(+), 982 deletions(-) create mode 100644 driver_test.go create mode 100644 stmt_document_select_test.go diff --git a/.github/workflows/gocosmos.yaml b/.github/workflows/gocosmos.yaml index 7545a87..b480a51 100644 --- a/.github/workflows/gocosmos.yaml +++ b/.github/workflows/gocosmos.yaml @@ -104,6 +104,32 @@ jobs: flags: driver_document_non_query name: driver_document_non_query + testDriverStmtDocumentQuery: + name: Test driver document query statements + runs-on: windows-latest + steps: + - name: Set up Go env + uses: actions/setup-go@v2 + with: + go-version: ^1.13 + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + - name: Test + run: | + choco install azure-cosmosdb-emulator + & "C:\Program Files\Azure Cosmos DB Emulator\Microsoft.Azure.Cosmos.Emulator.exe" /DisableRateLimiting /NoUI /NoExplorer + Start-Sleep -s 60 + try { Invoke-RestMethod -Method GET https://127.0.0.1:8081/ } catch {} + netstat -nt + $env:COSMOSDB_DRIVER='gocosmos' + $env:COSMOSDB_URL='AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==' + go test -cover -coverprofile="coverage_driver_document_query.txt" -v -count 1 -p 1 -run "TestStmtSelect_(Exec|Query)" . + - name: Codecov + uses: codecov/codecov-action@v3 + with: + flags: driver_document_query + name: driver_document_query + testDriverSelect: name: Test driver SELECT query runs-on: windows-latest diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 98b0176..deacbce 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -1,5 +1,9 @@ # gocosmos release notes +## 2023-06-xx - v0.2.1 + +- Bug fixes, Refactoring & Enhancements. + ## 2023-03-14 - v0.2.0 - `RestClient`: diff --git a/conn.go b/conn.go index 7bd0e8a..1c38789 100644 --- a/conn.go +++ b/conn.go @@ -24,7 +24,7 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) { // PrepareContext implements driver.ConnPrepareContext/PrepareContext. // -// @Available since v0.3.0 +// @Available since v0.2.1 func (c *Conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) { return parseQueryWithDefaultDb(c, c.defaultDb, query) } @@ -41,7 +41,7 @@ func (c *Conn) Begin() (driver.Tx, error) { // BeginTx implements driver.ConnBeginTx/BeginTx. // -// @Available since v0.3.0 +// @Available since v0.2.1 func (c *Conn) BeginTx(_ context.Context, _ driver.TxOptions) (driver.Tx, error) { return nil, errors.New("transaction is not supported") } diff --git a/driver.go b/driver.go index 315e02a..5446cd1 100644 --- a/driver.go +++ b/driver.go @@ -69,22 +69,22 @@ var ( // ErrPreconditionFailure is returned when operation specified an eTag that is different from the version available // at the server, that is, an optimistic concurrency error. // - // @Available since v0.3.0 + // @Available since v0.2.1 ErrPreconditionFailure = errors.New("StatusCode=412 Precondition failure") // ErrOperationNotSupported is returned to indicate that the operation is not supported. // - // @Available since v0.3.0 + // @Available since v0.2.1 ErrOperationNotSupported = errors.New("this operation is not supported") // ErrExecNotSupported is returned to indicate that the Exec/ExecContext operation is not supported. // - // @Available since v0.3.0 + // @Available since v0.2.1 ErrExecNotSupported = errors.New("this operation is not supported, please use Query") // ErrQueryNotSupported is returned to indicate that the Query/QueryContext operation is not supported. // - // @Available since v0.3.0 + // @Available since v0.2.1 ErrQueryNotSupported = errors.New("this operation is not supported, please use Exec") ) diff --git a/driver_test.go b/driver_test.go new file mode 100644 index 0000000..edd67c0 --- /dev/null +++ b/driver_test.go @@ -0,0 +1,89 @@ +package gocosmos + +import ( + "context" + "database/sql" + "strings" + "testing" +) + +func TestDriver_invalidConnectionString(t *testing.T) { + name := "TestDriver_invalidConnectionString" + driver := "gocosmos" + { + db, _ := sql.Open(driver, "AccountEndpoint;AccountKey=demo") + if err := db.Ping(); err == nil { + t.Fatalf("%s failed: should have error", name) + } + } + { + db, _ := sql.Open(driver, "AccountEndpoint=demo;AccountKey") + if err := db.Ping(); err == nil { + t.Fatalf("%s failed: should have error", name) + } + } + { + db, _ := sql.Open(driver, "AccountEndpoint=demo;AccountKey=demo/invalid_key") + if err := db.Ping(); err == nil { + t.Fatalf("%s failed: should have error", name) + } + } +} + +func TestDriver_missingEndpoint(t *testing.T) { + name := "TestDriver_missingEndpoint" + driver := "gocosmos" + dsn := "AccountKey=demo" + db, _ := sql.Open(driver, dsn) + if err := db.Ping(); err == nil { + t.Fatalf("%s failed: should have error", name) + } +} + +func TestDriver_missingAccountKey(t *testing.T) { + name := "TestDriver_missingAccountKey" + driver := "gocosmos" + dsn := "AccountEndpoint=demo" + db, _ := sql.Open(driver, dsn) + if err := db.Ping(); err == nil { + t.Fatalf("%s failed: should have error", name) + } +} + +func TestDriver_Conn(t *testing.T) { + name := "TestDriver_Conn" + db := _openDb(t, name) + _, err := db.Conn(context.Background()) + if err != nil { + t.Fatalf("%s failed: %s", name, err) + } +} + +func TestDriver_Transaction(t *testing.T) { + name := "TestDriver_Transaction" + db := _openDb(t, name) + if tx, err := db.BeginTx(context.Background(), nil); tx != nil || err == nil { + t.Fatalf("%s failed: transaction is not supported yet", name) + } else if strings.Index(err.Error(), "not supported") < 0 { + t.Fatalf("%s failed: transaction is not supported yet / %s", name, err) + } +} + +func TestDriver_Open(t *testing.T) { + name := "TestDriver_Open" + db := _openDb(t, name) + if err := db.Ping(); err != nil { + t.Fatalf("%s failed: %s", name, err) + } +} + +func TestDriver_Close(t *testing.T) { + name := "TestDriver_Close" + db := _openDb(t, name) + if err := db.Ping(); err != nil { + t.Fatalf("%s failed: %s", name, err) + } + if err := db.Close(); err != nil { + t.Fatalf("%s failed: %s", name, err) + } +} diff --git a/gocosmos_select_test.go b/gocosmos_select_test.go index 466fbf6..a3d327b 100644 --- a/gocosmos_select_test.go +++ b/gocosmos_select_test.go @@ -2,436 +2,9 @@ package gocosmos import ( "database/sql" - "encoding/json" - "fmt" - "reflect" - "sort" - "strconv" - "strings" - "sync" "testing" - - "github.com/btnguyen2k/consu/reddo" ) -func Test_Exec_Select(t *testing.T) { - name := "Test_Exec_Select" - db := _openDb(t, name) - _, err := db.Exec("SELECT * FROM c WITH db=db WITH collection=table") - if err == nil || strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: expected 'not support' error, but received %#v", name, err) - } -} - -/*----------------------------------------------------------------------*/ - -func _fetchAllRows(dbRows *sql.Rows) ([]map[string]interface{}, error) { - colTypes, err := dbRows.ColumnTypes() - if err != nil { - return nil, err - } - numCols := len(colTypes) - rows := make([]map[string]interface{}, 0) - for dbRows.Next() { - vals := make([]interface{}, numCols) - scanVals := make([]interface{}, numCols) - for i := 0; i < numCols; i++ { - scanVals[i] = &vals[i] - } - if err := dbRows.Scan(scanVals...); err == nil { - row := make(map[string]interface{}) - for i, v := range colTypes { - row[v.Name()] = vals[i] - } - rows = append(rows, row) - } else if err != sql.ErrNoRows { - return nil, err - } - } - return rows, nil -} - -/*----------------------------------------------------------------------*/ - -func _testSelectPkValue(t *testing.T, testName string, db *sql.DB, collname string) { - low, high := 123, 987 - lowStr, highStr := fmt.Sprintf("%05d", low), fmt.Sprintf("%05d", high) - countPerPartition := _countPerPartition(low, high, dataList) - distinctPerPartition := _distinctPerPartition(low, high, dataList, "category") - var testCases = []queryTestCase{ - {name: "NoLimit_Bare", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 WITH collection=%s WITH cross_partition=true"}, - {name: "OffsetLimit_Bare", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 OFFSET 3 LIMIT 5 WITH collection=%s WITH cross_partition=true", expectedNumItems: 5}, - {name: "NoLimit_OrderAsc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.grade WITH collection=%s WITH cross_partition=true", orderType: reddo.TypeInt, orderField: "grade", orderDirection: "asc"}, - {name: "OffsetLimit_OrderDesc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category DESC OFFSET 3 LIMIT 5 WITH collection=%s WITH cross_partition=true", expectedNumItems: 5, orderType: reddo.TypeInt, orderField: "category", orderDirection: "desc"}, - - {name: "NoLimit_DistinctValue", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 WITH collection=%s WITH cross_partition=true", distinctQuery: 1}, - {name: "NoLimit_DistinctDoc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1}, - {name: "OffsetLimit_DistinctValue", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: 1, expectedNumItems: 3}, - {name: "OffsetLimit_DistinctDoc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1, expectedNumItems: 3}, - - {name: "NoLimit_DistinctValue_OrderAsc", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category WITH collection=%s WITH cross_partition=true", distinctQuery: 1, orderType: reddo.TypeInt, orderField: "$1", orderDirection: "asc"}, - {name: "NoLimit_DistinctDoc_OrderDesc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category DESC WITH collection=%s WITH cross_partition=true", distinctQuery: -1, orderType: reddo.TypeInt, orderField: "category", orderDirection: "desc"}, - {name: "OffsetLimit_DistinctValue_OrderAsc", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: 1, orderType: reddo.TypeInt, orderField: "$1", orderDirection: "asc", expectedNumItems: 3}, - {name: "OffsetLimit_DistinctDoc_OrderDesc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category DESC OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1, orderType: reddo.TypeInt, orderField: "category", orderDirection: "desc", expectedNumItems: 3}, - - /* GROUP BY with ORDER BY is not supported! */ - {name: "NoLimit_GroupByCount", query: "SELECT c.category AS 'Category', count(1) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "count"}, - {name: "OffsetLimit_GroupByCount", query: "SELECT c.category AS 'Category', count(1) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "count"}, - {name: "NoLimit_GroupBySum", query: "SELECT c.category AS 'Category', sum(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "sum"}, - {name: "OffsetLimit_GroupBySum", query: "SELECT c.category AS 'Category', sum(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "sum"}, - {name: "NoLimit_GroupByMin", query: "SELECT c.category AS 'Category', min(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "min"}, - {name: "OffsetLimit_GroupByMin", query: "SELECT c.category AS 'Category', min(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "min"}, - {name: "NoLimit_GroupByMax", query: "SELECT c.category AS 'Category', max(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "max"}, - {name: "OffsetLimit_GroupByMax", query: "SELECT c.category AS 'Category', max(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "max"}, - {name: "NoLimit_GroupByAvg", query: "SELECT c.category AS 'Category', avg(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "average"}, - {name: "OffsetLimit_GroupByAvg", query: "SELECT c.category AS 'Category', avg(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "average"}, - } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - savedExpectedNumItems := testCase.expectedNumItems - for i := 0; i < numLogicalPartitions; i++ { - testCase.expectedNumItems = savedExpectedNumItems - expectedNumItems := testCase.expectedNumItems - username := "user" + strconv.Itoa(i) - params := []interface{}{lowStr, highStr, username} - if expectedNumItems <= 0 && testCase.maxItemCount <= 0 { - expectedNumItems = countPerPartition[username] - if testCase.distinctQuery != 0 { - expectedNumItems = distinctPerPartition[username] - } - testCase.expectedNumItems = expectedNumItems - } - sql := fmt.Sprintf(testCase.query, collname) - dbRows, err := db.Query(sql, params...) - if err != nil { - t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) - } - rows, err := _fetchAllRows(dbRows) - if err != nil { - t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) - } - _verifyResult(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name+"/pk="+username, testCase, expectedNumItems, rows) - _verifyDistinct(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name+"/pk="+username, testCase, rows) - _verifyOrderBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name+"/pk="+username, testCase, rows) - _verifyGroupBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name+"/pk="+username, testCase, username, lowStr, highStr, rows) - } - }) - } -} - -func TestSelect_PkValue_SmallRU(t *testing.T) { - testName := "TestSelect_PkValue_SmallRU" - dbname := testDb - collname := testTable - client := _newRestClient(t, testName) - _initDataSmallRU(t, testName, client, dbname, collname, 1000) - if result := client.GetPkranges(dbname, collname); result.Error() != nil { - t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) - } else if result.Count != 1 { - t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/GetPkranges", 1, result.Count) - } - db := _openDefaultDb(t, testName, dbname) - _testSelectPkValue(t, testName, db, collname) -} - -func TestSelect_PkValue_LargeRU(t *testing.T) { - testName := "TestSelect_PkValue_LargeRU" - dbname := testDb - collname := testTable - client := _newRestClient(t, testName) - _initDataLargeRU(t, testName, client, dbname, collname, 1000) - if result := client.GetPkranges(dbname, collname); result.Error() != nil { - t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) - } else if result.Count < 2 { - t.Fatalf("%s failed: expected to be larger than %#v but received %#v", testName+"/GetPkranges", 1, result.Count) - } - db := _openDefaultDb(t, testName, dbname) - _testSelectPkValue(t, testName, db, collname) -} - -/*----------------------------------------------------------------------*/ - -func _testSelectCrossPartition(t *testing.T, testName string, db *sql.DB, collname string) { - low, high := 123, 987 - lowStr, highStr := fmt.Sprintf("%05d", low), fmt.Sprintf("%05d", high) - var testCases = []queryTestCase{ - {name: "NoLimit_Bare", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 WITH collection=%s WITH cross_partition=true"}, - {name: "OffsetLimit_Bare", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 OFFSET 3 LIMIT 5 WITH collection=%s WITH cross_partition=true", expectedNumItems: 5}, - {name: "NoLimit_OrderAsc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.grade WITH collection=%s WITH cross_partition=true", orderType: reddo.TypeInt, orderField: "grade", orderDirection: "asc"}, - {name: "OffsetLimit_OrderDesc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.category DESC OFFSET 3 LIMIT 5 WITH collection=%s WITH cross_partition=true", expectedNumItems: 5, orderType: reddo.TypeInt, orderField: "category", orderDirection: "desc"}, - - {name: "NoLimit_DistinctValue", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 WITH collection=%s WITH cross_partition=true", distinctQuery: 1, expectedNumItems: numCategories}, - {name: "NoLimit_DistinctDoc", query: "SELECT DISTINCT c.username FROM c WHERE $1<=c.id AND c.id<@2 WITH collection=%s WITH cross_partition=true", distinctQuery: -1, expectedNumItems: numLogicalPartitions}, - {name: "OffsetLimit_DistinctValue", query: "SELECT DISTINCT VALUE c.username FROM c WHERE $1<=c.id AND c.id<@2 OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: 1, expectedNumItems: 3}, - {name: "OffsetLimit_DistinctDoc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1, expectedNumItems: 3}, - - {name: "NoLimit_DistinctValue_OrderAsc", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.category WITH collection=%s WITH cross_partition=true", distinctQuery: 1, orderType: reddo.TypeInt, orderField: "$1", orderDirection: "asc", expectedNumItems: numCategories}, - {name: "NoLimit_DistinctDoc_OrderDesc", query: "SELECT DISTINCT c.username FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.username DESC WITH collection=%s WITH cross_partition=true", distinctQuery: -1, orderType: reddo.TypeString, orderField: "username", orderDirection: "desc", expectedNumItems: numLogicalPartitions}, - {name: "OffsetLimit_DistinctValue_OrderAsc", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: 1, orderType: reddo.TypeInt, orderField: "$1", orderDirection: "asc", expectedNumItems: 3}, - {name: "OffsetLimit_DistinctDoc_OrderDesc", query: "SELECT DISTINCT c.username FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.username DESC OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1, orderType: reddo.TypeString, orderField: "username", orderDirection: "desc", expectedNumItems: 3}, - - /* GROUP BY with ORDER BY is not supported! */ - - {name: "NoLimit_GroupByCount", query: "SELECT c.category AS 'Category', count(1) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "count"}, - {name: "OffsetLimit_GroupByCount", query: "SELECT c.category AS 'Category', count(1) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "count"}, - {name: "NoLimit_GroupBySum", query: "SELECT c.category AS 'Category', sum(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "sum"}, - {name: "OffsetLimit_GroupBySum", query: "SELECT c.category AS 'Category', sum(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "sum"}, - {name: "NoLimit_GroupByMin", query: "SELECT c.category AS 'Category', min(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "min"}, - {name: "OffsetLimit_GroupByMin", query: "SELECT c.category AS 'Category', min(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "min"}, - {name: "NoLimit_GroupByMax", query: "SELECT c.category AS 'Category', max(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "max"}, - {name: "OffsetLimit_GroupByMax", query: "SELECT c.category AS 'Category', max(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "max"}, - {name: "NoLimit_GroupByAvg", query: "SELECT c.category AS 'Category', avg(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "average"}, - {name: "OffsetLimit_GroupByAvg", query: "SELECT c.category AS 'Category', avg(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "average"}, - } - params := []interface{}{lowStr, highStr} - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - expectedNumItems := high - low - if testCase.expectedNumItems > 0 { - expectedNumItems = testCase.expectedNumItems - } - sql := fmt.Sprintf(testCase.query, collname) - dbRows, err := db.Query(sql, params...) - if err != nil { - t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) - } - rows, err := _fetchAllRows(dbRows) - if err != nil { - t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) - } - _verifyResult(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, expectedNumItems, rows) - _verifyDistinct(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, rows) - _verifyOrderBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, rows) - _verifyGroupBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, "", lowStr, highStr, rows) - }) - } -} - -func TestSelect_CrossPartition_SmallRU(t *testing.T) { - testName := "TestSelect_CrossPartition_SmallRU" - dbname := testDb - collname := testTable - client := _newRestClient(t, testName) - _initDataSmallRU(t, testName, client, dbname, collname, 1000) - if result := client.GetPkranges(dbname, collname); result.Error() != nil { - t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) - } else if result.Count != 1 { - t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/GetPkranges", 1, result.Count) - } - db := _openDefaultDb(t, testName, dbname) - _testSelectCrossPartition(t, testName, db, collname) -} - -func TestSelect_CrossPartition_LargeRU(t *testing.T) { - testName := "TestSelect_CrossPartition_LargeRU" - dbname := testDb - collname := testTable - client := _newRestClient(t, testName) - _initDataLargeRU(t, testName, client, dbname, collname, 1000) - if result := client.GetPkranges(dbname, collname); result.Error() != nil { - t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) - } else if result.Count < 2 { - t.Fatalf("%s failed: expected to be larger than %#v but received %#v", testName+"/GetPkranges", 1, result.Count) - } - db := _openDefaultDb(t, testName, dbname) - _testSelectCrossPartition(t, testName, db, collname) -} - -/*----------------------------------------------------------------------*/ - -func _testSelectPaging(t *testing.T, testName string, db *sql.DB, collname string, pkranges *RespGetPkranges) { - low, high := 123, 987 - lowStr, highStr := fmt.Sprintf("%05d", low), fmt.Sprintf("%05d", high) - var testCases = []queryTestCase{ - {name: "Simple_OrderAsc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.id OFFSET :3 LIMIT 23 WITH collection=%s WITH cross_partition=true", maxItemCount: 23, orderField: "id", orderType: reddo.TypeString, orderDirection: "asc"}, - {name: "Simple_OrderDesc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.id DESC OFFSET :3 LIMIT 29 WITH collection=%s WITH cross_partition=true", maxItemCount: 29, orderField: "id", orderType: reddo.TypeString, orderDirection: "desc"}, - - {name: "DistinctDoc_OrderAsc", query: "SELECT DISTINCT c.username FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.username OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, orderField: "username", orderType: reddo.TypeString, orderDirection: "asc", expectedNumItems: numLogicalPartitions, distinctQuery: -1, distinctField: "username"}, - {name: "DistinctValue_OrderDesc", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.category DESC OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, orderField: "$1", orderType: reddo.TypeInt, orderDirection: "desc", expectedNumItems: numCategories, distinctQuery: 1, distinctField: "$1"}, - - {name: "GroupByCount", query: "SELECT c.category AS 'Category', count(1) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, groupByAggr: "count", expectedNumItems: numCategories}, - {name: "GroupBySum", query: "SELECT c.category AS 'Category', sum(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, groupByAggr: "sum", expectedNumItems: numCategories}, - {name: "GroupByMin", query: "SELECT c.category AS 'Category', min(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, groupByAggr: "min", expectedNumItems: numCategories}, - {name: "GroupByMax", query: "SELECT c.category AS 'Category', max(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, groupByAggr: "max", expectedNumItems: numCategories}, - {name: "GroupByAvg", query: "SELECT c.category AS 'Category', avg(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, groupByAggr: "average", expectedNumItems: numCategories}, - } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - expectedNumItems := high - low - if testCase.expectedNumItems > 0 { - expectedNumItems = testCase.expectedNumItems - } - sql := fmt.Sprintf(testCase.query, collname) - offset := 0 - finalRows := make([]map[string]interface{}, 0) - for { - params := []interface{}{lowStr, highStr, offset} - dbRows, err := db.Query(sql, params...) - if err != nil { - t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) - } - rows, err := _fetchAllRows(dbRows) - if err != nil { - t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) - } - if offset == 0 || len(rows) != 0 { - _verifyResult(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, 0, rows) - _verifyOrderBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, rows) - _verifyDistinct(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, rows) - } - if len(rows) == 0 { - break - } - finalRows = append(finalRows, rows...) - offset += len(rows) - } - testCase.maxItemCount = 0 - // { - // for i, row := range finalRows { - // fmt.Printf("%5d: %s\n", i, row["id"]) - // } - // } - _verifyResult(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, expectedNumItems, finalRows) - _verifyOrderBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, finalRows) - _verifyDistinct(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, finalRows) - _verifyGroupBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, "", lowStr, highStr, finalRows) - }) - } -} - -func TestSelect_Paging_SmallRU(t *testing.T) { - testName := "TestSelect_Paging_SmallRU" - dbname := testDb - collname := testTable - client := _newRestClient(t, testName) - _initDataSmallRU(t, testName, client, dbname, collname, 1000) - pkranges := client.GetPkranges(dbname, collname) - if pkranges.Error() != nil { - t.Fatalf("%s failed: %s", testName+"/GetPkranges", pkranges.Error()) - } else if pkranges.Count != 1 { - t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/GetPkranges", 1, pkranges.Count) - } - db := _openDefaultDb(t, testName, dbname) - _testSelectPaging(t, testName, db, collname, pkranges) -} - -func TestSelect_Paging_LargeRU(t *testing.T) { - testName := "TestSelect_Paging_LargeRU" - dbname := testDb - collname := testTable - client := _newRestClient(t, testName) - _initDataLargeRU(t, testName, client, dbname, collname, 1000) - pkranges := client.GetPkranges(dbname, collname) - if pkranges.Error() != nil { - t.Fatalf("%s failed: %s", testName+"/GetPkranges", pkranges.Error()) - } else if pkranges.Count < 2 { - t.Fatalf("%s failed: expected to be larger than %#v but received %#v", testName+"/GetPkranges", 1, pkranges.Count) - } - db := _openDefaultDb(t, testName, dbname) - _testSelectPaging(t, testName, db, collname, pkranges) -} - -/*----------------------------------------------------------------------*/ - -func _testSelectCustomDataset(t *testing.T, testName string, testCases []customQueryTestCase, db *sql.DB, collname string) { - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - sql := fmt.Sprintf(testCase.query, collname) - dbRows, err := db.Query(sql) - if err != nil { - t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) - } - rows, err := _fetchAllRows(dbRows) - if err != nil { - t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) - } - - var expectedResult []interface{} - json.Unmarshal([]byte(testCase.expectedResultJson), &expectedResult) - if len(rows) != len(expectedResult) { - t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/"+testCase.name, len(expectedResult), len(rows)) - } - if !testCase.ordering { - sort.Slice(rows, func(i, j int) bool { - var doci, docj = rows[i], rows[j] - stri, _ := json.Marshal(doci[testCase.compareField]) - strj, _ := json.Marshal(docj[testCase.compareField]) - return string(stri) < string(strj) - }) - sort.Slice(expectedResult, func(i, j int) bool { - var doci, docj = expectedResult[i].(map[string]interface{}), expectedResult[j].(map[string]interface{}) - stri, _ := json.Marshal(doci[testCase.compareField]) - strj, _ := json.Marshal(docj[testCase.compareField]) - return string(stri) < string(strj) - }) - } - for i, row := range rows { - expected := expectedResult[i] - if !reflect.DeepEqual(row, expected) { - // fmt.Printf("DEBUG: %#v\n", rows) - // fmt.Printf("DEBUG: %#v\n", expectedResult) - t.Fatalf("%s failed: result\n%#v\ndoes not match expected one\n%#v", testName+"/"+testCase.name, row, expected) - } - } - }) - } -} - -func _testSelectDatasetFamilies(t *testing.T, testName string, db *sql.DB, collname string) { - var testCases = []customQueryTestCase{ - // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/getting-started - // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/select - // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/from - // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/order-by - // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/group-by - // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/offset-limit - {name: "QuerySingleDoc", compareField: "id", query: `SELECT * FROM Families f WHERE f.id = "AndersenFamily" WITH collection=%s WITH cross_partition=true`, expectedResultJson: _toJson([]DocInfo{dataMapFamilies["AndersenFamily"]})}, - {name: "QuerySingleAttr", compareField: "id", query: `SELECT f.address FROM Families f WHERE f.id = "AndersenFamily" WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"address":{"state":"WA","county":"King","city":"Seattle"}}]`}, - {name: "QuerySubAttrs", compareField: "id", query: `SELECT {"Name":f.id, "City":f.address.city} AS Family FROM Families f WHERE f.address.city = f.address.state WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"Family":{"Name":"WakefieldFamily","City":"NY"}}]`}, - {name: "QuerySubItems1", compareField: "$1", query: `SELECT * FROM Families.children WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"$1":[{"firstName":"Henriette Thaulow","gender":"female","grade":5,"pets":[{"givenName":"Fluffy"}]}]},{"$1":[{"familyName":"Merriam","gender":"female","givenName":"Jesse","grade":1,"pets":[{"givenName":"Goofy"},{"givenName":"Shadow"}]},{"familyName":"Miller","gender":"female","givenName":"Lisa","grade":8}]}]`}, - {name: "QuerySubItems2", compareField: "$1", query: `SELECT * FROM Families.address.state WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"$1":"WA"},{"$1":"NY"}]`}, - {name: "QuerySingleAttrWithOrderBy", ordering: true, query: `SELECT c.givenName FROM Families f JOIN c IN f.children WHERE f.id = 'WakefieldFamily' ORDER BY f.address.city ASC WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"givenName":"Jesse"},{"givenName":"Lisa"}]`}, - {name: "QuerySubAttrsWithOrderByAsc", ordering: true, query: `SELECT f.id, f.address.city FROM Families f ORDER BY f.address.city WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"id":"WakefieldFamily","city":"NY"},{"id":"AndersenFamily","city":"Seattle"}]`}, - {name: "QuerySubAttrsWithOrderByDesc", ordering: true, query: `SELECT f.id, f.creationDate FROM Families f ORDER BY f.creationDate DESC WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"id":"AndersenFamily","creationDate":1431620472},{"id":"WakefieldFamily","creationDate":1431620462}]`}, - {name: "QuerySubAttrsWithOrderByMissingField", ordering: false, query: `SELECT f.id, f.lastName FROM Families f ORDER BY f.lastName WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"id":"WakefieldFamily","lastName":null},{"id":"AndersenFamily","lastName":"Andersen"}]`}, - {name: "QueryGroupBy", compareField: "$1", query: `SELECT COUNT(UniqueLastNames) FROM (SELECT AVG(f.age) FROM f GROUP BY f.lastName) AS UniqueLastNames WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"$1":2}]`}, - {name: "QueryOffsetLimitWithOrderBy", compareField: "id", query: `SELECT f.id, f.address.city FROM Families f ORDER BY f.address.city OFFSET 1 LIMIT 1 WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"id":"AndersenFamily","city":"Seattle"}]`}, - // without ORDER BY, the returned result is un-deterministic - // {name: "QueryOffsetLimitWithoutOrderBy", compareField: "id", query: `SELECT f.id, f.address.city FROM Families f OFFSET 1 LIMIT 1 WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"id":"AndersenFamily","city":"Seattle"}]`}, - } - _testSelectCustomDataset(t, testName, testCases, db, collname) -} - -func TestSelect_DatasetFamilies_SmallRU(t *testing.T) { - testName := "TestSelect_DatasetFamilies_SmallRU" - client := _newRestClient(t, testName) - dbname := testDb - collname := testTable - _initDataFamliesSmallRU(t, testName, client, dbname, collname) - if result := client.GetPkranges(dbname, collname); result.Error() != nil { - t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) - } else if result.Count != 1 { - t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/GetPkranges", 1, result.Count) - } - db := _openDefaultDb(t, testName, dbname) - _testSelectDatasetFamilies(t, testName, db, collname) -} - -func TestSelect_DatasetFamilies_LargeRU(t *testing.T) { - testName := "TestSelect_DatasetFamilies_LargeRU" - client := _newRestClient(t, testName) - dbname := testDb - collname := testTable - _initDataFamliesLargeRU(t, testName, client, dbname, collname) - if result := client.GetPkranges(dbname, collname); result.Error() != nil { - t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) - } else if result.Count < 2 { - t.Fatalf("%s failed: expected to be larger than %#v but received %#v", testName+"/GetPkranges", 1, result.Count) - } - db := _openDefaultDb(t, testName, dbname) - _testSelectDatasetFamilies(t, testName, db, collname) -} - func _testSelectDatasetNutrition(t *testing.T, testName string, db *sql.DB, collname string) { var testCases = []customQueryTestCase{ // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/group-by @@ -475,350 +48,3 @@ func TestSelect_DatasetNutrition_LargeRU(t *testing.T) { db := _openDefaultDb(t, testName, dbname) _testSelectDatasetNutrition(t, testName, db, collname) } - -/*----------------------------------------------------------------------*/ - -func Test_Query_Select(t *testing.T) { - name := "Test_Query_Select" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_exists") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - if _, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/username WITH uk=/email"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - for i := 0; i < 100; i++ { - id := fmt.Sprintf("%02d", i) - username := "user" + strconv.Itoa(i%4) - db.Exec("INSERT INTO dbtemp.tbltemp (id,username,email,grade) VALUES (:1,@2,$3,:4)", id, username, "user"+id+"@domain.com", i, username) - } - - if dbRows, err := db.Query(`SELECT * FROM c WHERE c.username="user0" AND c.id>"30" ORDER BY c.id WITH database=dbtemp WITH collection=tbltemp`); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else { - colTypes, err := dbRows.ColumnTypes() - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - numCols := len(colTypes) - rows := make(map[string]map[string]interface{}) - for dbRows.Next() { - vals := make([]interface{}, numCols) - scanVals := make([]interface{}, numCols) - for i := 0; i < numCols; i++ { - scanVals[i] = &vals[i] - } - if err := dbRows.Scan(scanVals...); err == nil { - row := make(map[string]interface{}) - for i, v := range colTypes { - row[v.Name()] = vals[i] - } - id := fmt.Sprintf("%s", row["id"]) - rows[id] = row - } else if err != sql.ErrNoRows { - t.Fatalf("%s failed: %s", name, err) - } - } - if len(rows) != 17 { - t.Fatalf("%s failed: expected %#v but received %#v", name, 17, len(rows)) - } - for k := range rows { - if k <= "30" { - t.Fatalf("%s failed: document #%s should not be returned", name, k) - } - } - } - - if dbRows, err := db.Query(`SELECT CROSS PARTITION * FROM tbltemp c WHERE c.username>"user1" AND c.id>"53" WITH database=dbtemp`); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else { - colTypes, err := dbRows.ColumnTypes() - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - numCols := len(colTypes) - rows := make(map[string]map[string]interface{}) - for dbRows.Next() { - vals := make([]interface{}, numCols) - scanVals := make([]interface{}, numCols) - for i := 0; i < numCols; i++ { - scanVals[i] = &vals[i] - } - if err := dbRows.Scan(scanVals...); err == nil { - row := make(map[string]interface{}) - for i, v := range colTypes { - row[v.Name()] = vals[i] - } - id := fmt.Sprintf("%s", row["id"]) - rows[id] = row - } else if err != sql.ErrNoRows { - t.Fatalf("%s failed: %s", name, err) - } - } - if len(rows) != 24 { - t.Fatalf("%s failed: expected %#v but received %#v", name, 24, len(rows)) - } - for k := range rows { - if k <= "53" { - t.Fatalf("%s failed: document #%s should not be returned", name, k) - } - } - } - - if _, err := db.Query(`SELECT * FROM c WITH db=dbtemp WITH collection=tbl_not_found`); err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } - - if _, err := db.Query(`SELECT * FROM c WITH db=db_not_found WITH collection=tbltemp`); err != ErrNotFound { - t.Fatalf("%s failed: expected ErrNotFound but received %#v", name, err) - } -} - -func Test_Query_SelectLongList(t *testing.T) { - name := "Test_Query_SelectLongList" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_exists") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - if _, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/username WITH uk=/email"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - for i := 0; i < 1000; i++ { - id := fmt.Sprintf("%03d", i) - username := "user" + strconv.Itoa(i%4) - db.Exec("INSERT INTO dbtemp.tbltemp (id,username,email,grade) VALUES (:1,@2,$3,:4)", id, username, "user"+id+"@domain.com", i, username) - } - - if dbRows, err := db.Query(`SELECT * FROM c WHERE c.username="user0" AND c.id>"030" ORDER BY c.id WITH database=dbtemp WITH collection=tbltemp`); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else { - colTypes, err := dbRows.ColumnTypes() - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - numCols := len(colTypes) - rows := make(map[string]map[string]interface{}) - for dbRows.Next() { - vals := make([]interface{}, numCols) - scanVals := make([]interface{}, numCols) - for i := 0; i < numCols; i++ { - scanVals[i] = &vals[i] - } - if err := dbRows.Scan(scanVals...); err == nil { - row := make(map[string]interface{}) - for i, v := range colTypes { - row[v.Name()] = vals[i] - } - id := fmt.Sprintf("%s", row["id"]) - rows[id] = row - } else if err != sql.ErrNoRows { - t.Fatalf("%s failed: %s", name, err) - } - } - if len(rows) != 242 { - t.Fatalf("%s failed: expected %#v but received %#v", name, 242, len(rows)) - } - for k := range rows { - if k <= "030" { - t.Fatalf("%s failed: document #%s should not be returned", name, k) - } - } - } -} - -func Test_Query_SelectPlaceholder(t *testing.T) { - name := "Test_Query_SelectPlaceholder" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS db_not_exists") - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - if _, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/username WITH uk=/email"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - for i := 0; i < 100; i++ { - id := fmt.Sprintf("%02d", i) - username := "user" + strconv.Itoa(i%4) - db.Exec("INSERT INTO dbtemp.tbltemp (id,username,email,grade) VALUES (:1,@2,$3,:4)", id, username, "user"+id+"@domain.com", i, username) - } - - if dbRows, err := db.Query(`SELECT * FROM c WHERE c.username=$2 AND c.id>:1 ORDER BY c.id WITH database=dbtemp WITH collection=tbltemp`, "30", "user0"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else { - colTypes, err := dbRows.ColumnTypes() - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - numCols := len(colTypes) - rows := make(map[string]map[string]interface{}) - for dbRows.Next() { - vals := make([]interface{}, numCols) - scanVals := make([]interface{}, numCols) - for i := 0; i < numCols; i++ { - scanVals[i] = &vals[i] - } - if err := dbRows.Scan(scanVals...); err == nil { - row := make(map[string]interface{}) - for i, v := range colTypes { - row[v.Name()] = vals[i] - } - id := fmt.Sprintf("%s", row["id"]) - rows[id] = row - } else if err != sql.ErrNoRows { - t.Fatalf("%s failed: %s", name, err) - } - } - if len(rows) != 17 { - t.Fatalf("%s failed: expected %#v but received %#v", name, 17, len(rows)) - } - for k := range rows { - if k <= "30" { - t.Fatalf("%s failed: document #%s should not be returned", name, k) - } - } - } - - if dbRows, err := db.Query(`SELECT CROSS PARTITION * FROM tbltemp WHERE tbltemp.username>@1 AND tbltemp.grade>:2 WITH database=dbtemp`, "user1", 53); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else { - colTypes, err := dbRows.ColumnTypes() - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - numCols := len(colTypes) - rows := make(map[string]map[string]interface{}) - for dbRows.Next() { - vals := make([]interface{}, numCols) - scanVals := make([]interface{}, numCols) - for i := 0; i < numCols; i++ { - scanVals[i] = &vals[i] - } - if err := dbRows.Scan(scanVals...); err == nil { - row := make(map[string]interface{}) - for i, v := range colTypes { - row[v.Name()] = vals[i] - } - id := fmt.Sprintf("%s", row["id"]) - rows[id] = row - } else if err != sql.ErrNoRows { - t.Fatalf("%s failed: %s", name, err) - } - } - if len(rows) != 24 { - t.Fatalf("%s failed: expected %#v but received %#v", name, 24, len(rows)) - } - for k := range rows { - if k <= "53" { - t.Fatalf("%s failed: document #%s should not be returned", name, k) - } - } - } - - if _, err := db.Query(`SELECT * FROM c WHERE c.username=$2 AND c.id>:10 ORDER BY c.id WITH database=dbtemp WITH collection=tbltemp`, "30", "user0"); err == nil || strings.Index(err.Error(), "no placeholder") < 0 { - t.Fatalf("%s failed: expecting 'no placeholder' but received %s", name, err) - } -} - -func Test_Query_SelectPkranges(t *testing.T) { - name := "Test_Query_SelectPkranges" - db := _openDb(t, name) - - db.Exec("DROP DATABASE IF EXISTS dbtemp") - db.Exec("CREATE DATABASE IF NOT EXISTS dbtemp") - defer db.Exec("DROP DATABASE IF EXISTS dbtemp") - if _, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/username WITH uk=/email"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - - var wait sync.WaitGroup - n := 1000 - d := 256 - wait.Add(n) - for i := 0; i < n; i++ { - go func(i int) { - id := fmt.Sprintf("%04d", i) - username := "user" + fmt.Sprintf("%02x", i%d) - email := "user" + strconv.Itoa(i) + "@domain.com" - db.Exec("INSERT INTO dbtemp.tbltemp (id,username,email,grade) VALUES (:1,@2,$3,:4)", id, username, email, i, username) - wait.Done() - }(i) - } - wait.Wait() - - query := `SELECT CROSS PARTITION * FROM c WHERE c.id>$1 ORDER BY c.id OFFSET 5 LIMIT 23 WITH database=dbtemp WITH collection=tbltemp` - if dbRows, err := db.Query(query, "0123"); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else { - colTypes, err := dbRows.ColumnTypes() - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - numCols := len(colTypes) - rows := make(map[string]map[string]interface{}) - for dbRows.Next() { - vals := make([]interface{}, numCols) - scanVals := make([]interface{}, numCols) - for i := 0; i < numCols; i++ { - scanVals[i] = &vals[i] - } - if err := dbRows.Scan(scanVals...); err == nil { - row := make(map[string]interface{}) - for i, v := range colTypes { - row[v.Name()] = vals[i] - } - id := fmt.Sprintf("%s", row["id"]) - rows[id] = row - } else if err != sql.ErrNoRows { - t.Fatalf("%s failed: %s", name, err) - } - } - if len(rows) != 23 { - t.Fatalf("%s failed: expected %#v but received %#v", name, 23, len(rows)) - } - for k := range rows { - if k <= "0123" { - t.Fatalf("%s failed: document #%s should not be returned", name, k) - } - } - } - - query = `SELECT c.username, sum(c.index) FROM tbltemp c WHERE c.id<"0123" GROUP BY c.username OFFSET 110 LIMIT 20 WITH database=dbtemp WITH cross_partition=true` - if dbRows, err := db.Query(query); err != nil { - t.Fatalf("%s failed: %s", name, err) - } else { - colTypes, err := dbRows.ColumnTypes() - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - numCols := len(colTypes) - rows := make(map[string]map[string]interface{}) - for dbRows.Next() { - vals := make([]interface{}, numCols) - scanVals := make([]interface{}, numCols) - for i := 0; i < numCols; i++ { - scanVals[i] = &vals[i] - } - if err := dbRows.Scan(scanVals...); err == nil { - row := make(map[string]interface{}) - for i, v := range colTypes { - row[v.Name()] = vals[i] - } - id := fmt.Sprintf("%s", row["username"]) - rows[id] = row - } else if err != sql.ErrNoRows { - t.Fatalf("%s failed: %s", name, err) - } - } - if len(rows) != 13 { - t.Fatalf("%s failed: expected %#v but received %#v", name, 13, len(rows)) - } - } -} diff --git a/gocosmos_test.go b/gocosmos_test.go index f6dc0ea..e4cb052 100644 --- a/gocosmos_test.go +++ b/gocosmos_test.go @@ -1,69 +1,12 @@ package gocosmos import ( - "context" "database/sql" "os" "strings" "testing" ) -func Test_OpenDatabase(t *testing.T) { - name := "Test_OpenDatabase" - driver := "gocosmos" - dsn := "dummy" - db, err := sql.Open(driver, dsn) - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if db == nil { - t.Fatalf("%s failed: nil", name) - } -} - -func TestDriver_invalidConnectionString(t *testing.T) { - name := "TestDriver_invalidConnectionString" - driver := "gocosmos" - { - db, _ := sql.Open(driver, "AccountEndpoint;AccountKey=demo") - if err := db.Ping(); err == nil { - t.Fatalf("%s failed: should have error", name) - } - } - { - db, _ := sql.Open(driver, "AccountEndpoint=demo;AccountKey") - if err := db.Ping(); err == nil { - t.Fatalf("%s failed: should have error", name) - } - } - { - db, _ := sql.Open(driver, "AccountEndpoint=demo;AccountKey=demo/invalid_key") - if err := db.Ping(); err == nil { - t.Fatalf("%s failed: should have error", name) - } - } -} - -func TestDriver_missingEndpoint(t *testing.T) { - name := "TestDriver_missingEndpoint" - driver := "gocosmos" - dsn := "AccountKey=demo" - db, _ := sql.Open(driver, dsn) - if err := db.Ping(); err == nil { - t.Fatalf("%s failed: should have error", name) - } -} - -func TestDriver_missingAccountKey(t *testing.T) { - name := "TestDriver_missingAccountKey" - driver := "gocosmos" - dsn := "AccountEndpoint=demo" - db, _ := sql.Open(driver, dsn) - if err := db.Ping(); err == nil { - t.Fatalf("%s failed: should have error", name) - } -} - func _openDefaultDb(t *testing.T, testName, defaultDb string) *sql.DB { driver := "gocosmos" url := strings.ReplaceAll(os.Getenv("COSMOSDB_URL"), `"`, "") @@ -85,41 +28,3 @@ func _openDefaultDb(t *testing.T, testName, defaultDb string) *sql.DB { func _openDb(t *testing.T, testName string) *sql.DB { return _openDefaultDb(t, testName, "") } - -func TestDriver_Conn(t *testing.T) { - name := "TestDriver_Conn" - db := _openDb(t, name) - _, err := db.Conn(context.Background()) - if err != nil { - t.Fatalf("%s failed: %s", name, err) - } -} - -func TestDriver_Transaction(t *testing.T) { - name := "TestDriver_Transaction" - db := _openDb(t, name) - if tx, err := db.BeginTx(context.Background(), nil); tx != nil || err == nil { - t.Fatalf("%s failed: transaction is not supported yet", name) - } else if strings.Index(err.Error(), "not supported") < 0 { - t.Fatalf("%s failed: transaction is not supported yet / %s", name, err) - } -} - -func TestDriver_Open(t *testing.T) { - name := "TestDriver_Open" - db := _openDb(t, name) - if err := db.Ping(); err != nil { - t.Fatalf("%s failed: %s", name, err) - } -} - -func TestDriver_Close(t *testing.T) { - name := "TestDriver_Close" - db := _openDb(t, name) - if err := db.Ping(); err != nil { - t.Fatalf("%s failed: %s", name, err) - } - if err := db.Close(); err != nil { - t.Fatalf("%s failed: %s", name, err) - } -} diff --git a/stmt.go b/stmt.go index ef5ac66..efac824 100644 --- a/stmt.go +++ b/stmt.go @@ -242,47 +242,52 @@ func (s *Stmt) NumInput() int { /*----------------------------------------------------------------------*/ -func buildResultNoResultSet(restResponse *RestReponse, supportLastInsertId bool, rid string, ignoreErrorCode int) *ResultNoResultSet { - result := &ResultNoResultSet{ - err: restResponse.Error(), - lastInsertId: rid, - supportLastInsertId: supportLastInsertId, - } - if result.err == nil { - result.affectedRows = 1 - } - switch restResponse.StatusCode { +func normalizeError(statusCode, ignoreErrorCode int, err error) error { + switch statusCode { case 403: if ignoreErrorCode == 403 { - result.err = nil + return nil } else { - result.err = ErrForbidden + return ErrForbidden } case 404: if ignoreErrorCode == 404 { - result.err = nil + return nil } else { - result.err = ErrNotFound + return ErrNotFound } case 409: if ignoreErrorCode == 409 { - result.err = nil + return nil } else { - result.err = ErrConflict + return ErrConflict } case 412: if ignoreErrorCode == 412 { - result.err = nil + return nil } else { - result.err = ErrPreconditionFailure + return ErrPreconditionFailure } } + return err +} + +func buildResultNoResultSet(restResponse *RestReponse, supportLastInsertId bool, rid string, ignoreErrorCode int) *ResultNoResultSet { + result := &ResultNoResultSet{ + err: restResponse.Error(), + lastInsertId: rid, + supportLastInsertId: supportLastInsertId, + } + if result.err == nil { + result.affectedRows = 1 + } + result.err = normalizeError(restResponse.StatusCode, ignoreErrorCode, result.err) return result } // ResultNoResultSet captures the result from statements that do not expect a ResultSet to be returned. // -// @Available since v0.3.0 +// @Available since v0.2.1 type ResultNoResultSet struct { err error affectedRows int64 @@ -310,26 +315,44 @@ func (r *ResultNoResultSet) RowsAffected() (int64, error) { // ResultResultSet captures the result from statements that expect a ResultSet to be returned. // -// @Available since v0.3.0 +// @Available since v0.2.1 type ResultResultSet struct { err error count int cursorCount int columnList []string columnTypes map[string]reflect.Type - rowData []map[string]interface{} + rows []DocInfo + documents QueriedDocs } func (r *ResultResultSet) init() *ResultResultSet { - if r.rowData == nil { + if r.rows == nil && r.documents == nil { return r } + + if r.rows == nil { + documents := r.documents.AsDocInfoSlice() + if documents == nil { + // special case: result from a query like "SELECT COUNT(...)" + documents = make([]DocInfo, len(r.documents)) + for i, doc := range r.documents { + var docInfo DocInfo = map[string]interface{}{"$1": doc} + documents[i] = docInfo + } + } + for i, doc := range documents { + documents[i] = doc.RemoveSystemAttrs() + } + r.rows = documents + } + if r.columnTypes == nil { r.columnTypes = make(map[string]reflect.Type) } - r.count = len(r.rowData) + r.count = len(r.rows) colMap := make(map[string]bool) - for _, item := range r.rowData { + for _, item := range r.rows { for col, val := range item { colMap[col] = true if r.columnTypes[col] == nil { @@ -374,7 +397,7 @@ func (r *ResultResultSet) Next(dest []driver.Value) error { if r.cursorCount >= r.count { return io.EOF } - rowData := r.rowData[r.cursorCount] + rowData := r.rows[r.cursorCount] r.cursorCount++ for i, colName := range r.columnList { dest[i] = rowData[colName] diff --git a/stmt_collection.go b/stmt_collection.go index f56f287..31ee81f 100644 --- a/stmt_collection.go +++ b/stmt_collection.go @@ -278,9 +278,9 @@ func (s *StmtListCollections) Query(_ []driver.Value) (driver.Rows, error) { } if result.err == nil { result.count = len(restResult.Collections) - result.rowData = make([]map[string]interface{}, result.count) + result.rows = make([]DocInfo, result.count) for i, coll := range restResult.Collections { - result.rowData[i] = coll.toMap() + result.rows[i] = coll.toMap() } } switch restResult.StatusCode { diff --git a/stmt_database.go b/stmt_database.go index 0fb4d81..cf15992 100644 --- a/stmt_database.go +++ b/stmt_database.go @@ -204,9 +204,9 @@ func (s *StmtListDatabases) Query(_ []driver.Value) (driver.Rows, error) { } if result.err == nil { result.count = len(restResult.Databases) - result.rowData = make([]map[string]interface{}, result.count) + result.rows = make([]DocInfo, result.count) for i, db := range restResult.Databases { - result.rowData[i] = db.toMap() + result.rows[i] = db.toMap() } } switch restResult.StatusCode { diff --git a/stmt_document.go b/stmt_document.go index cb219cb..23c0c53 100644 --- a/stmt_document.go +++ b/stmt_document.go @@ -5,9 +5,7 @@ import ( "encoding/json" "errors" "fmt" - "io" "regexp" - "sort" "strconv" "strings" ) @@ -258,7 +256,10 @@ func (s *StmtDelete) Query(_ []driver.Value) (driver.Rows, error) { // // Syntax: // -// SELECT [CROSS PARTITION] ... FROM ... WITH database|db= [WITH collection|table=] [WITH cross_partition=true] +// SELECT [CROSS PARTITION] ... FROM ... +// WITH database|db= +// [WITH collection|table=] +// [WITH cross_partition=true] // // - (extension) If the collection is partitioned, specify "CROSS PARTITION" to allow execution across multiple partitions. // This clause is not required if query is to be executed on a single partition. @@ -321,8 +322,7 @@ func (s *StmtSelect) validate() error { return nil } -// Query implements driver.Stmt.Query. -// Upon successful call, this function returns (*ResultSelect, nil). +// Query implements driver.Stmt/Query. func (s *StmtSelect) Query(args []driver.Value) (driver.Rows, error) { params := make([]interface{}, 0) for i, arg := range args { @@ -340,83 +340,20 @@ func (s *StmtSelect) Query(args []driver.Value) (driver.Rows, error) { CrossPartitionEnabled: s.isCrossPartition, } - result := s.conn.restClient.QueryDocumentsCrossPartition(query) - err := result.Error() - var rows driver.Rows - if err == nil { - documents := result.Documents.AsDocInfoSlice() - if documents == nil { - documents = make([]DocInfo, len(result.Documents)) - for i, doc := range result.Documents { - var docInfo DocInfo = map[string]interface{}{"$1": doc} - documents[i] = docInfo - } - } - for i, doc := range documents { - documents[i] = doc.RemoveSystemAttrs() - } - rows = &ResultSelect{count: len(documents), documents: documents, cursorCount: 0, columnList: make([]string, 0)} - - // build column list - columnList := make([]string, 0) - columnMap := make(map[string]bool) - for _, doc := range documents { - for colName := range doc { - if _, ok := columnMap[colName]; !ok { - columnMap[colName] = true - columnList = append(columnList, colName) - } - } - } - sort.Strings(columnList) - rows.(*ResultSelect).columnList = columnList + restResult := s.conn.restClient.QueryDocumentsCrossPartition(query) + result := &ResultResultSet{err: restResult.Error(), columnList: make([]string, 0)} + if result.err == nil { + result.documents = restResult.Documents + result.init() } - switch result.StatusCode { - case 403: - err = ErrForbidden - case 404: - err = ErrNotFound - // case 409: - // err = ErrConflict - } - return rows, err + result.err = normalizeError(restResult.StatusCode, 0, result.err) + return result, result.err } -// Exec implements driver.Stmt.Exec. +// Exec implements driver.Stmt/Exec. // This function is not implemented, use Query instead. -func (s *StmtSelect) Exec(args []driver.Value) (driver.Result, error) { - return nil, errors.New("this operation is not supported, please use query") -} - -// ResultSelect captures the result from SELECT operation. -type ResultSelect struct { - count int - documents []DocInfo - cursorCount int - columnList []string -} - -// Columns implements driver.Rows.Columns. -func (r *ResultSelect) Columns() []string { - return r.columnList -} - -// Close implements driver.Rows.Close. -func (r *ResultSelect) Close() error { - return nil -} - -// Next implements driver.Rows.Next. -func (r *ResultSelect) Next(dest []driver.Value) error { - if r.cursorCount >= r.count { - return io.EOF - } - rowData := r.documents[r.cursorCount] - r.cursorCount++ - for i, colName := range r.columnList { - dest[i] = rowData[colName] - } - return nil +func (s *StmtSelect) Exec(_ []driver.Value) (driver.Result, error) { + return nil, ErrExecNotSupported } /*----------------------------------------------------------------------*/ diff --git a/stmt_document_select_test.go b/stmt_document_select_test.go new file mode 100644 index 0000000..80f6374 --- /dev/null +++ b/stmt_document_select_test.go @@ -0,0 +1,632 @@ +package gocosmos + +import ( + "database/sql" + "encoding/json" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "testing" + + "github.com/btnguyen2k/consu/reddo" +) + +func TestStmtSelect_Exec(t *testing.T) { + testName := "TestStmtSelect_Exec" + db := _openDb(t, testName) + _, err := db.Exec("SELECT * FROM c WITH db=db WITH collection=table") + if err != ErrExecNotSupported { + t.Fatalf("%s failed: expected ErrQueryNotSupported, but received %#v", testName, err) + } +} + +/*----------------------------------------------------------------------*/ + +func _testSelectPkValue(t *testing.T, testName string, db *sql.DB, collname string) { + low, high := 123, 987 + lowStr, highStr := fmt.Sprintf("%05d", low), fmt.Sprintf("%05d", high) + countPerPartition := _countPerPartition(low, high, dataList) + distinctPerPartition := _distinctPerPartition(low, high, dataList, "category") + var testCases = []queryTestCase{ + {name: "NoLimit_Bare", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 WITH collection=%s WITH cross_partition=true"}, + {name: "OffsetLimit_Bare", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 OFFSET 3 LIMIT 5 WITH collection=%s WITH cross_partition=true", expectedNumItems: 5}, + {name: "NoLimit_OrderAsc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.grade WITH collection=%s WITH cross_partition=true", orderType: reddo.TypeInt, orderField: "grade", orderDirection: "asc"}, + {name: "OffsetLimit_OrderDesc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category DESC OFFSET 3 LIMIT 5 WITH collection=%s WITH cross_partition=true", expectedNumItems: 5, orderType: reddo.TypeInt, orderField: "category", orderDirection: "desc"}, + + {name: "NoLimit_DistinctValue", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 WITH collection=%s WITH cross_partition=true", distinctQuery: 1}, + {name: "NoLimit_DistinctDoc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1}, + {name: "OffsetLimit_DistinctValue", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: 1, expectedNumItems: 3}, + {name: "OffsetLimit_DistinctDoc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1, expectedNumItems: 3}, + + {name: "NoLimit_DistinctValue_OrderAsc", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category WITH collection=%s WITH cross_partition=true", distinctQuery: 1, orderType: reddo.TypeInt, orderField: "$1", orderDirection: "asc"}, + {name: "NoLimit_DistinctDoc_OrderDesc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category DESC WITH collection=%s WITH cross_partition=true", distinctQuery: -1, orderType: reddo.TypeInt, orderField: "category", orderDirection: "desc"}, + {name: "OffsetLimit_DistinctValue_OrderAsc", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: 1, orderType: reddo.TypeInt, orderField: "$1", orderDirection: "asc", expectedNumItems: 3}, + {name: "OffsetLimit_DistinctDoc_OrderDesc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category DESC OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1, orderType: reddo.TypeInt, orderField: "category", orderDirection: "desc", expectedNumItems: 3}, + + /* GROUP BY with ORDER BY is not supported! */ + {name: "NoLimit_GroupByCount", query: "SELECT c.category AS 'Category', count(1) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "count"}, + {name: "OffsetLimit_GroupByCount", query: "SELECT c.category AS 'Category', count(1) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "count"}, + {name: "NoLimit_GroupBySum", query: "SELECT c.category AS 'Category', sum(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "sum"}, + {name: "OffsetLimit_GroupBySum", query: "SELECT c.category AS 'Category', sum(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "sum"}, + {name: "NoLimit_GroupByMin", query: "SELECT c.category AS 'Category', min(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "min"}, + {name: "OffsetLimit_GroupByMin", query: "SELECT c.category AS 'Category', min(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "min"}, + {name: "NoLimit_GroupByMax", query: "SELECT c.category AS 'Category', max(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "max"}, + {name: "OffsetLimit_GroupByMax", query: "SELECT c.category AS 'Category', max(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "max"}, + {name: "NoLimit_GroupByAvg", query: "SELECT c.category AS 'Category', avg(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "average"}, + {name: "OffsetLimit_GroupByAvg", query: "SELECT c.category AS 'Category', avg(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "average"}, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + savedExpectedNumItems := testCase.expectedNumItems + for i := 0; i < numLogicalPartitions; i++ { + testCase.expectedNumItems = savedExpectedNumItems + expectedNumItems := testCase.expectedNumItems + username := "user" + strconv.Itoa(i) + params := []interface{}{lowStr, highStr, username} + if expectedNumItems <= 0 && testCase.maxItemCount <= 0 { + expectedNumItems = countPerPartition[username] + if testCase.distinctQuery != 0 { + expectedNumItems = distinctPerPartition[username] + } + testCase.expectedNumItems = expectedNumItems + } + sql := fmt.Sprintf(testCase.query, collname) + dbRows, err := db.Query(sql, params...) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + _verifyResult(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name+"/pk="+username, testCase, expectedNumItems, rows) + _verifyDistinct(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name+"/pk="+username, testCase, rows) + _verifyOrderBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name+"/pk="+username, testCase, rows) + _verifyGroupBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name+"/pk="+username, testCase, username, lowStr, highStr, rows) + } + }) + } +} + +func TestStmtSelect_Query_PkValue_SmallRU(t *testing.T) { + testName := "TestStmtSelect_Query_PkValue_SmallRU" + dbname := testDb + collname := testTable + client := _newRestClient(t, testName) + _initDataSmallRU(t, testName, client, dbname, collname, 1000) + if result := client.GetPkranges(dbname, collname); result.Error() != nil { + t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) + } else if result.Count != 1 { + t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/GetPkranges", 1, result.Count) + } + db := _openDefaultDb(t, testName, dbname) + _testSelectPkValue(t, testName, db, collname) +} + +func TestStmtSelect_Query_PkValue_LargeRU(t *testing.T) { + testName := "TestStmtSelect_Query_PkValue_LargeRU" + dbname := testDb + collname := testTable + client := _newRestClient(t, testName) + _initDataLargeRU(t, testName, client, dbname, collname, 1000) + if result := client.GetPkranges(dbname, collname); result.Error() != nil { + t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) + } else if result.Count < 2 { + t.Fatalf("%s failed: expected to be larger than %#v but received %#v", testName+"/GetPkranges", 1, result.Count) + } + db := _openDefaultDb(t, testName, dbname) + _testSelectPkValue(t, testName, db, collname) +} + +/*----------------------------------------------------------------------*/ + +func _testSelectCrossPartition(t *testing.T, testName string, db *sql.DB, collname string) { + low, high := 123, 987 + lowStr, highStr := fmt.Sprintf("%05d", low), fmt.Sprintf("%05d", high) + var testCases = []queryTestCase{ + {name: "NoLimit_Bare", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 WITH collection=%s WITH cross_partition=true"}, + {name: "OffsetLimit_Bare", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 OFFSET 3 LIMIT 5 WITH collection=%s WITH cross_partition=true", expectedNumItems: 5}, + {name: "NoLimit_OrderAsc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.grade WITH collection=%s WITH cross_partition=true", orderType: reddo.TypeInt, orderField: "grade", orderDirection: "asc"}, + {name: "OffsetLimit_OrderDesc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.category DESC OFFSET 3 LIMIT 5 WITH collection=%s WITH cross_partition=true", expectedNumItems: 5, orderType: reddo.TypeInt, orderField: "category", orderDirection: "desc"}, + + {name: "NoLimit_DistinctValue", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 WITH collection=%s WITH cross_partition=true", distinctQuery: 1, expectedNumItems: numCategories}, + {name: "NoLimit_DistinctDoc", query: "SELECT DISTINCT c.username FROM c WHERE $1<=c.id AND c.id<@2 WITH collection=%s WITH cross_partition=true", distinctQuery: -1, expectedNumItems: numLogicalPartitions}, + {name: "OffsetLimit_DistinctValue", query: "SELECT DISTINCT VALUE c.username FROM c WHERE $1<=c.id AND c.id<@2 OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: 1, expectedNumItems: 3}, + {name: "OffsetLimit_DistinctDoc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1, expectedNumItems: 3}, + + {name: "NoLimit_DistinctValue_OrderAsc", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.category WITH collection=%s WITH cross_partition=true", distinctQuery: 1, orderType: reddo.TypeInt, orderField: "$1", orderDirection: "asc", expectedNumItems: numCategories}, + {name: "NoLimit_DistinctDoc_OrderDesc", query: "SELECT DISTINCT c.username FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.username DESC WITH collection=%s WITH cross_partition=true", distinctQuery: -1, orderType: reddo.TypeString, orderField: "username", orderDirection: "desc", expectedNumItems: numLogicalPartitions}, + {name: "OffsetLimit_DistinctValue_OrderAsc", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: 1, orderType: reddo.TypeInt, orderField: "$1", orderDirection: "asc", expectedNumItems: 3}, + {name: "OffsetLimit_DistinctDoc_OrderDesc", query: "SELECT DISTINCT c.username FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.username DESC OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1, orderType: reddo.TypeString, orderField: "username", orderDirection: "desc", expectedNumItems: 3}, + + /* GROUP BY with ORDER BY is not supported! */ + + {name: "NoLimit_GroupByCount", query: "SELECT c.category AS 'Category', count(1) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "count"}, + {name: "OffsetLimit_GroupByCount", query: "SELECT c.category AS 'Category', count(1) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "count"}, + {name: "NoLimit_GroupBySum", query: "SELECT c.category AS 'Category', sum(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "sum"}, + {name: "OffsetLimit_GroupBySum", query: "SELECT c.category AS 'Category', sum(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "sum"}, + {name: "NoLimit_GroupByMin", query: "SELECT c.category AS 'Category', min(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "min"}, + {name: "OffsetLimit_GroupByMin", query: "SELECT c.category AS 'Category', min(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "min"}, + {name: "NoLimit_GroupByMax", query: "SELECT c.category AS 'Category', max(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "max"}, + {name: "OffsetLimit_GroupByMax", query: "SELECT c.category AS 'Category', max(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "max"}, + {name: "NoLimit_GroupByAvg", query: "SELECT c.category AS 'Category', avg(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "average"}, + {name: "OffsetLimit_GroupByAvg", query: "SELECT c.category AS 'Category', avg(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "average"}, + } + params := []interface{}{lowStr, highStr} + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + expectedNumItems := high - low + if testCase.expectedNumItems > 0 { + expectedNumItems = testCase.expectedNumItems + } + sql := fmt.Sprintf(testCase.query, collname) + dbRows, err := db.Query(sql, params...) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + _verifyResult(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, expectedNumItems, rows) + _verifyDistinct(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, rows) + _verifyOrderBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, rows) + _verifyGroupBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, "", lowStr, highStr, rows) + }) + } +} + +func TestStmtSelect_Query_CrossPartition_SmallRU(t *testing.T) { + testName := "TestStmtSelect_Query_CrossPartition_SmallRU" + dbname := testDb + collname := testTable + client := _newRestClient(t, testName) + _initDataSmallRU(t, testName, client, dbname, collname, 1000) + if result := client.GetPkranges(dbname, collname); result.Error() != nil { + t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) + } else if result.Count != 1 { + t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/GetPkranges", 1, result.Count) + } + db := _openDefaultDb(t, testName, dbname) + _testSelectCrossPartition(t, testName, db, collname) +} + +func TestStmtSelect_Query_CrossPartition_LargeRU(t *testing.T) { + testName := "TestStmtSelect_Query_CrossPartition_LargeRU" + dbname := testDb + collname := testTable + client := _newRestClient(t, testName) + _initDataLargeRU(t, testName, client, dbname, collname, 1000) + if result := client.GetPkranges(dbname, collname); result.Error() != nil { + t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) + } else if result.Count < 2 { + t.Fatalf("%s failed: expected to be larger than %#v but received %#v", testName+"/GetPkranges", 1, result.Count) + } + db := _openDefaultDb(t, testName, dbname) + _testSelectCrossPartition(t, testName, db, collname) +} + +/*----------------------------------------------------------------------*/ + +func _testSelectPaging(t *testing.T, testName string, db *sql.DB, collname string, pkranges *RespGetPkranges) { + low, high := 123, 987 + lowStr, highStr := fmt.Sprintf("%05d", low), fmt.Sprintf("%05d", high) + var testCases = []queryTestCase{ + {name: "Simple_OrderAsc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.id OFFSET :3 LIMIT 23 WITH collection=%s WITH cross_partition=true", maxItemCount: 23, orderField: "id", orderType: reddo.TypeString, orderDirection: "asc"}, + {name: "Simple_OrderDesc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.id DESC OFFSET :3 LIMIT 29 WITH collection=%s WITH cross_partition=true", maxItemCount: 29, orderField: "id", orderType: reddo.TypeString, orderDirection: "desc"}, + + {name: "DistinctDoc_OrderAsc", query: "SELECT DISTINCT c.username FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.username OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, orderField: "username", orderType: reddo.TypeString, orderDirection: "asc", expectedNumItems: numLogicalPartitions, distinctQuery: -1, distinctField: "username"}, + {name: "DistinctValue_OrderDesc", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 ORDER BY c.category DESC OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, orderField: "$1", orderType: reddo.TypeInt, orderDirection: "desc", expectedNumItems: numCategories, distinctQuery: 1, distinctField: "$1"}, + + {name: "GroupByCount", query: "SELECT c.category AS 'Category', count(1) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, groupByAggr: "count", expectedNumItems: numCategories}, + {name: "GroupBySum", query: "SELECT c.category AS 'Category', sum(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, groupByAggr: "sum", expectedNumItems: numCategories}, + {name: "GroupByMin", query: "SELECT c.category AS 'Category', min(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, groupByAggr: "min", expectedNumItems: numCategories}, + {name: "GroupByMax", query: "SELECT c.category AS 'Category', max(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, groupByAggr: "max", expectedNumItems: numCategories}, + {name: "GroupByAvg", query: "SELECT c.category AS 'Category', avg(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 GROUP BY c.category OFFSET :3 LIMIT 3 WITH collection=%s WITH cross_partition=true", maxItemCount: 3, groupByAggr: "average", expectedNumItems: numCategories}, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + expectedNumItems := high - low + if testCase.expectedNumItems > 0 { + expectedNumItems = testCase.expectedNumItems + } + sql := fmt.Sprintf(testCase.query, collname) + offset := 0 + finalRows := make([]map[string]interface{}, 0) + for { + params := []interface{}{lowStr, highStr, offset} + dbRows, err := db.Query(sql, params...) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + if offset == 0 || len(rows) != 0 { + _verifyResult(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, 0, rows) + _verifyOrderBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, rows) + _verifyDistinct(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, rows) + } + if len(rows) == 0 { + break + } + finalRows = append(finalRows, rows...) + offset += len(rows) + } + testCase.maxItemCount = 0 + // { + // for i, row := range finalRows { + // fmt.Printf("%5d: %s\n", i, row["id"]) + // } + // } + _verifyResult(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, expectedNumItems, finalRows) + _verifyOrderBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, finalRows) + _verifyDistinct(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, finalRows) + _verifyGroupBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name, testCase, "", lowStr, highStr, finalRows) + }) + } +} + +func TestStmtSelect_Query_Paging_SmallRU(t *testing.T) { + testName := "TestStmtSelect_Query_Paging_SmallRU" + dbname := testDb + collname := testTable + client := _newRestClient(t, testName) + _initDataSmallRU(t, testName, client, dbname, collname, 1000) + pkranges := client.GetPkranges(dbname, collname) + if pkranges.Error() != nil { + t.Fatalf("%s failed: %s", testName+"/GetPkranges", pkranges.Error()) + } else if pkranges.Count != 1 { + t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/GetPkranges", 1, pkranges.Count) + } + db := _openDefaultDb(t, testName, dbname) + _testSelectPaging(t, testName, db, collname, pkranges) +} + +func TestStmtSelect_Query_Paging_LargeRU(t *testing.T) { + testName := "TestStmtSelect_Query_Paging_LargeRU" + dbname := testDb + collname := testTable + client := _newRestClient(t, testName) + _initDataLargeRU(t, testName, client, dbname, collname, 1000) + pkranges := client.GetPkranges(dbname, collname) + if pkranges.Error() != nil { + t.Fatalf("%s failed: %s", testName+"/GetPkranges", pkranges.Error()) + } else if pkranges.Count < 2 { + t.Fatalf("%s failed: expected to be larger than %#v but received %#v", testName+"/GetPkranges", 1, pkranges.Count) + } + db := _openDefaultDb(t, testName, dbname) + _testSelectPaging(t, testName, db, collname, pkranges) +} + +/*----------------------------------------------------------------------*/ + +func _testSelectCustomDataset(t *testing.T, testName string, testCases []customQueryTestCase, db *sql.DB, collname string) { + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + sql := fmt.Sprintf(testCase.query, collname) + dbRows, err := db.Query(sql) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + + var expectedResult []interface{} + json.Unmarshal([]byte(testCase.expectedResultJson), &expectedResult) + if len(rows) != len(expectedResult) { + t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/"+testCase.name, len(expectedResult), len(rows)) + } + if !testCase.ordering { + sort.Slice(rows, func(i, j int) bool { + var doci, docj = rows[i], rows[j] + stri, _ := json.Marshal(doci[testCase.compareField]) + strj, _ := json.Marshal(docj[testCase.compareField]) + return string(stri) < string(strj) + }) + sort.Slice(expectedResult, func(i, j int) bool { + var doci, docj = expectedResult[i].(map[string]interface{}), expectedResult[j].(map[string]interface{}) + stri, _ := json.Marshal(doci[testCase.compareField]) + strj, _ := json.Marshal(docj[testCase.compareField]) + return string(stri) < string(strj) + }) + } + for i, row := range rows { + expected := expectedResult[i] + if !reflect.DeepEqual(row, expected) { + // fmt.Printf("DEBUG: %#v\n", rows) + // fmt.Printf("DEBUG: %#v\n", expectedResult) + t.Fatalf("%s failed: result\n%#v\ndoes not match expected one\n%#v", testName+"/"+testCase.name, row, expected) + } + } + }) + } +} + +func _testSelectDatasetFamilies(t *testing.T, testName string, db *sql.DB, collname string) { + var testCases = []customQueryTestCase{ + // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/getting-started + // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/select + // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/from + // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/order-by + // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/group-by + // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/offset-limit + {name: "QuerySingleDoc", compareField: "id", query: `SELECT * FROM Families f WHERE f.id = "AndersenFamily" WITH collection=%s WITH cross_partition=true`, expectedResultJson: _toJson([]DocInfo{dataMapFamilies["AndersenFamily"]})}, + {name: "QuerySingleAttr", compareField: "id", query: `SELECT f.address FROM Families f WHERE f.id = "AndersenFamily" WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"address":{"state":"WA","county":"King","city":"Seattle"}}]`}, + {name: "QuerySubAttrs", compareField: "id", query: `SELECT {"Name":f.id, "City":f.address.city} AS Family FROM Families f WHERE f.address.city = f.address.state WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"Family":{"Name":"WakefieldFamily","City":"NY"}}]`}, + {name: "QuerySubItems1", compareField: "$1", query: `SELECT * FROM Families.children WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"$1":[{"firstName":"Henriette Thaulow","gender":"female","grade":5,"pets":[{"givenName":"Fluffy"}]}]},{"$1":[{"familyName":"Merriam","gender":"female","givenName":"Jesse","grade":1,"pets":[{"givenName":"Goofy"},{"givenName":"Shadow"}]},{"familyName":"Miller","gender":"female","givenName":"Lisa","grade":8}]}]`}, + {name: "QuerySubItems2", compareField: "$1", query: `SELECT * FROM Families.address.state WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"$1":"WA"},{"$1":"NY"}]`}, + {name: "QuerySingleAttrWithOrderBy", ordering: true, query: `SELECT c.givenName FROM Families f JOIN c IN f.children WHERE f.id = 'WakefieldFamily' ORDER BY f.address.city ASC WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"givenName":"Jesse"},{"givenName":"Lisa"}]`}, + {name: "QuerySubAttrsWithOrderByAsc", ordering: true, query: `SELECT f.id, f.address.city FROM Families f ORDER BY f.address.city WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"id":"WakefieldFamily","city":"NY"},{"id":"AndersenFamily","city":"Seattle"}]`}, + {name: "QuerySubAttrsWithOrderByDesc", ordering: true, query: `SELECT f.id, f.creationDate FROM Families f ORDER BY f.creationDate DESC WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"id":"AndersenFamily","creationDate":1431620472},{"id":"WakefieldFamily","creationDate":1431620462}]`}, + {name: "QuerySubAttrsWithOrderByMissingField", ordering: false, query: `SELECT f.id, f.lastName FROM Families f ORDER BY f.lastName WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"id":"WakefieldFamily","lastName":null},{"id":"AndersenFamily","lastName":"Andersen"}]`}, + {name: "QueryGroupBy", compareField: "$1", query: `SELECT COUNT(UniqueLastNames) FROM (SELECT AVG(f.age) FROM f GROUP BY f.lastName) AS UniqueLastNames WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"$1":2}]`}, + {name: "QueryOffsetLimitWithOrderBy", compareField: "id", query: `SELECT f.id, f.address.city FROM Families f ORDER BY f.address.city OFFSET 1 LIMIT 1 WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"id":"AndersenFamily","city":"Seattle"}]`}, + // without ORDER BY, the returned result is un-deterministic + // {name: "QueryOffsetLimitWithoutOrderBy", compareField: "id", query: `SELECT f.id, f.address.city FROM Families f OFFSET 1 LIMIT 1 WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"id":"AndersenFamily","city":"Seattle"}]`}, + } + _testSelectCustomDataset(t, testName, testCases, db, collname) +} + +func TestStmtSelect_Query_DatasetFamilies_SmallRU(t *testing.T) { + testName := "TestStmtSelect_Query_DatasetFamilies_SmallRU" + client := _newRestClient(t, testName) + dbname := testDb + collname := testTable + _initDataFamliesSmallRU(t, testName, client, dbname, collname) + if result := client.GetPkranges(dbname, collname); result.Error() != nil { + t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) + } else if result.Count != 1 { + t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/GetPkranges", 1, result.Count) + } + db := _openDefaultDb(t, testName, dbname) + _testSelectDatasetFamilies(t, testName, db, collname) +} + +func TestStmtSelect_Query_DatasetFamilies_LargeRU(t *testing.T) { + testName := "TestStmtSelect_Query_DatasetFamilies_LargeRU" + client := _newRestClient(t, testName) + dbname := testDb + collname := testTable + _initDataFamliesLargeRU(t, testName, client, dbname, collname) + if result := client.GetPkranges(dbname, collname); result.Error() != nil { + t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) + } else if result.Count < 2 { + t.Fatalf("%s failed: expected to be larger than %#v but received %#v", testName+"/GetPkranges", 1, result.Count) + } + db := _openDefaultDb(t, testName, dbname) + _testSelectDatasetFamilies(t, testName, db, collname) +} + +/*----------------------------------------------------------------------*/ + +func TestStmtSelect_Query(t *testing.T) { + testName := "TestStmtSelect_Query" + db := _openDb(t, testName) + dbname := "dbtemp" + db.Exec("DROP DATABASE IF EXISTS db_not_exists") + db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbname)) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + if _, err := db.Exec("CREATE COLLECTION dbtemp.tbltemp WITH pk=/username WITH uk=/email"); err != nil { + t.Fatalf("%s failed: %s", testName, err) + } + + for i := 0; i < 100; i++ { + id := fmt.Sprintf("%02d", i) + username := "user" + strconv.Itoa(i%4) + db.Exec(fmt.Sprintf("INSERT INTO %s.tbltemp (id,username,email,grade) VALUES (:1,@2,$3,:4)", dbname), id, username, "user"+id+"@domain.com", i, username) + } + + if dbRows, err := db.Query(fmt.Sprintf(`SELECT * FROM c WHERE c.username="user0" AND c.id>"30" ORDER BY c.id WITH database=%s WITH collection=tbltemp`, dbname)); err != nil { + t.Fatalf("%s failed: %s", testName, err) + } else { + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName, err) + } + + if len(rows) != 17 { + t.Fatalf("%s failed: expected %#v but received %#v", testName, 17, len(rows)) + } + for _, row := range rows { + id := row["id"].(string) + if id <= "30" { + t.Fatalf("%s failed: document #%s should not be returned", testName, id) + } + } + } + + if dbRows, err := db.Query(fmt.Sprintf(`SELECT CROSS PARTITION * FROM tbltemp c WHERE c.username>"user1" AND c.id>"53" WITH database=%s`, dbname)); err != nil { + t.Fatalf("%s failed: %s", testName, err) + } else { + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName, err) + } + + if len(rows) != 24 { + t.Fatalf("%s failed: expected %#v but received %#v", testName, 24, len(rows)) + } + for _, row := range rows { + id := row["id"].(string) + if id <= "53" { + t.Fatalf("%s failed: document #%s should not be returned", testName, id) + } + } + } + + if _, err := db.Query(fmt.Sprintf(`SELECT * FROM c WITH db=%s WITH collection=tbl_not_found`, dbname)); err != ErrNotFound { + t.Fatalf("%s failed: expected ErrNotFound but received %#v", testName, err) + } + + if _, err := db.Query(`SELECT * FROM c WITH db=db_not_found WITH collection=tbltemp`); err != ErrNotFound { + t.Fatalf("%s failed: expected ErrNotFound but received %#v", testName, err) + } +} + +func TestStmtSelect_Query_SelectLongList(t *testing.T) { + testName := "TestStmtSelect_Query_SelectLongList" + db := _openDb(t, testName) + dbname := "dbtemp" + db.Exec("DROP DATABASE IF EXISTS db_not_exists") + db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbname)) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + if _, err := db.Exec(fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname)); err != nil { + t.Fatalf("%s failed: %s", testName, err) + } + + for i := 0; i < 1000; i++ { + id := fmt.Sprintf("%03d", i) + username := "user" + strconv.Itoa(i%4) + db.Exec(fmt.Sprintf("INSERT INTO %s.tbltemp (id,username,email,grade) VALUES (:1,@2,$3,:4)", dbname), id, username, "user"+id+"@domain.com", i, username) + } + + if dbRows, err := db.Query(fmt.Sprintf(`SELECT * FROM c WHERE c.username="user0" AND c.id>"030" ORDER BY c.id WITH database=%s WITH collection=tbltemp`, dbname)); err != nil { + t.Fatalf("%s failed: %s", testName, err) + } else { + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName, err) + } + + if len(rows) != 242 { + t.Fatalf("%s failed: expected %#v but received %#v", testName, 242, len(rows)) + } + for _, row := range rows { + id := row["id"].(string) + if id <= "030" { + t.Fatalf("%s failed: document #%s should not be returned", testName, id) + } + } + } +} + +func TestStmtSelect_Query_SelectPlaceholder(t *testing.T) { + testName := "TestStmtSelect_Query_SelectPlaceholder" + db := _openDb(t, testName) + dbname := "dbtemp" + db.Exec("DROP DATABASE IF EXISTS db_not_exists") + db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbname)) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + if _, err := db.Exec(fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname)); err != nil { + t.Fatalf("%s failed: %s", testName, err) + } + + for i := 0; i < 100; i++ { + id := fmt.Sprintf("%02d", i) + username := "user" + strconv.Itoa(i%4) + db.Exec(fmt.Sprintf("INSERT INTO %s.tbltemp (id,username,email,grade) VALUES (:1,@2,$3,:4)", dbname), id, username, "user"+id+"@domain.com", i, username) + } + + if dbRows, err := db.Query(fmt.Sprintf(`SELECT * FROM c WHERE c.username=$2 AND c.id>:1 ORDER BY c.id WITH database=%s WITH collection=tbltemp`, dbname), "30", "user0"); err != nil { + t.Fatalf("%s failed: %s", testName, err) + } else { + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName, err) + } + + if len(rows) != 17 { + t.Fatalf("%s failed: expected %#v but received %#v", testName, 17, len(rows)) + } + for _, row := range rows { + id := row["id"].(string) + if id <= "30" { + t.Fatalf("%s failed: document #%s should not be returned", testName, id) + } + } + } + + if dbRows, err := db.Query(fmt.Sprintf(`SELECT CROSS PARTITION * FROM tbltemp WHERE tbltemp.username>@1 AND tbltemp.grade>:2 WITH database=%s`, dbname), "user1", 53); err != nil { + t.Fatalf("%s failed: %s", testName, err) + } else { + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName, err) + } + + if len(rows) != 24 { + t.Fatalf("%s failed: expected %#v but received %#v", testName, 24, len(rows)) + } + for _, row := range rows { + id := row["id"].(string) + if id <= "53" { + t.Fatalf("%s failed: document #%s should not be returned", testName, id) + } + } + } + + if _, err := db.Query(fmt.Sprintf(`SELECT * FROM c WHERE c.username=$2 AND c.id>:10 ORDER BY c.id WITH database=%s WITH collection=tbltemp`, dbname), "30", "user0"); err == nil || strings.Index(err.Error(), "no placeholder") < 0 { + t.Fatalf("%s failed: expecting 'no placeholder' but received %s", testName, err) + } +} + +func TestStmtSelect_Query_SelectPkranges(t *testing.T) { + testName := "TestStmtSelect_Query_SelectPkranges" + db := _openDb(t, testName) + dbname := "dbtemp" + db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbname)) + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + if _, err := db.Exec(fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname)); err != nil { + t.Fatalf("%s failed: %s", testName, err) + } + + var wait sync.WaitGroup + n := 1000 + d := 256 + wait.Add(n) + for i := 0; i < n; i++ { + go func(i int) { + id := fmt.Sprintf("%04d", i) + username := "user" + fmt.Sprintf("%02x", i%d) + email := "user" + strconv.Itoa(i) + "@domain.com" + db.Exec(fmt.Sprintf("INSERT INTO %s.tbltemp (id,username,email,grade) VALUES (:1,@2,$3,:4)", dbname), id, username, email, i, username) + wait.Done() + }(i) + } + wait.Wait() + + query := fmt.Sprintf(`SELECT CROSS PARTITION * FROM c WHERE c.id>$1 ORDER BY c.id OFFSET 5 LIMIT 23 WITH database=%s WITH collection=tbltemp`, dbname) + if dbRows, err := db.Query(query, "0123"); err != nil { + t.Fatalf("%s failed: %s", testName, err) + } else { + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName, err) + } + + if len(rows) != 23 { + t.Fatalf("%s failed: expected %#v but received %#v", testName, 23, len(rows)) + } + for _, row := range rows { + id := row["id"].(string) + if id <= "0123" { + t.Fatalf("%s failed: document #%s should not be returned", testName, id) + } + } + } + + query = fmt.Sprintf(`SELECT c.username, sum(c.index) FROM tbltemp c WHERE c.id<"0123" GROUP BY c.username OFFSET 110 LIMIT 20 WITH database=%s WITH cross_partition=true`, dbname) + if dbRows, err := db.Query(query); err != nil { + t.Fatalf("%s failed: %s", testName, err) + } else { + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName, err) + } + + if len(rows) != 13 { + t.Fatalf("%s failed: expected %#v but received %#v", testName, 13, len(rows)) + } + } +} diff --git a/stmt_test.go b/stmt_test.go index 3c14fad..16ba149 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -1,9 +1,36 @@ package gocosmos import ( + "database/sql" "testing" ) +func _fetchAllRows(dbRows *sql.Rows) ([]map[string]interface{}, error) { + colTypes, err := dbRows.ColumnTypes() + if err != nil { + return nil, err + } + numCols := len(colTypes) + rows := make([]map[string]interface{}, 0) + for dbRows.Next() { + vals := make([]interface{}, numCols) + scanVals := make([]interface{}, numCols) + for i := 0; i < numCols; i++ { + scanVals[i] = &vals[i] + } + if err := dbRows.Scan(scanVals...); err == nil { + row := make(map[string]interface{}) + for i, v := range colTypes { + row[v.Name()] = vals[i] + } + rows = append(rows, row) + } else if err != sql.ErrNoRows { + return nil, err + } + } + return rows, nil +} + func TestStmt_NumInput(t *testing.T) { name := "TestStmt_NumInput" testData := map[string]int{ From fa51589eef4f664fb569d1791ce4bd170f543dc4 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Thu, 8 Jun 2023 13:19:57 +1000 Subject: [PATCH 11/13] update tests --- data_test.go | 6 ++--- gocosmos_select_test.go | 50 ------------------------------------ stmt_document_select_test.go | 44 +++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 53 deletions(-) delete mode 100644 gocosmos_select_test.go diff --git a/data_test.go b/data_test.go index 2564cda..80552f8 100644 --- a/data_test.go +++ b/data_test.go @@ -28501,8 +28501,8 @@ func _initDataNutrition(t *testing.T, testName string, client *RestClient, db, c for { now := time.Now() d := now.Sub(start) - r := float64(numDocWritten) / (d.Seconds() + 0.01) - if r <= 123.0 { + r := float64(numDocWritten) / (d.Seconds() + 0.001) + if r <= 81.19 { break } fmt.Printf("\t[DEBUG] too fast, slowing down...(Id: %d / NumDocs: %d / Dur: %.3f / Rate: %.3f)\n", id, numDocWritten, d.Seconds(), r) @@ -28520,7 +28520,7 @@ func _initDataNutrition(t *testing.T, testName string, client *RestClient, db, c { now := time.Now() d := now.Sub(start) - r := float64(numDocWritten) / (d.Seconds() + 0.01) + r := float64(numDocWritten) / (d.Seconds() + 0.001) fmt.Printf("\t[DEBUG] Dur: %.3f / Rate: %.3f\n", d.Seconds(), r) time.Sleep(1*time.Second + time.Duration(rand.Intn(1234))*time.Millisecond) } diff --git a/gocosmos_select_test.go b/gocosmos_select_test.go deleted file mode 100644 index a3d327b..0000000 --- a/gocosmos_select_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package gocosmos - -import ( - "database/sql" - "testing" -) - -func _testSelectDatasetNutrition(t *testing.T, testName string, db *sql.DB, collname string) { - var testCases = []customQueryTestCase{ - // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/group-by - {name: "Count", query: `SELECT COUNT(1) AS foodGroupCount FROM Food f WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"foodGroupCount": 8618}]`}, - {name: "QueryGroupBy1", compareField: "foodGroupCount", - query: "SELECT COUNT(1) AS foodGroupCount, UPPER(f.foodGroup) AS upperFoodGroup FROM Food f GROUP BY UPPER(f.foodGroup) WITH collection=%s WITH cross_partition=true", - expectedResultJson: `[{"foodGroupCount":64,"upperFoodGroup":"SPICES AND HERBS"},{"foodGroupCount":108,"upperFoodGroup":"RESTAURANT FOODS"},{"foodGroupCount":113,"upperFoodGroup":"MEALS, ENTREES, AND SIDE DISHES"},{"foodGroupCount":133,"upperFoodGroup":"NUT AND SEED PRODUCTS"},{"foodGroupCount":165,"upperFoodGroup":"AMERICAN INDIAN/ALASKA NATIVE FOODS"},{"foodGroupCount":171,"upperFoodGroup":"SNACKS"},{"foodGroupCount":183,"upperFoodGroup":"CEREAL GRAINS AND PASTA"},{"foodGroupCount":219,"upperFoodGroup":"FATS AND OILS"},{"foodGroupCount":244,"upperFoodGroup":"SAUSAGES AND LUNCHEON MEATS"},{"foodGroupCount":264,"upperFoodGroup":"DAIRY AND EGG PRODUCTS"},{"foodGroupCount":267,"upperFoodGroup":"FINFISH AND SHELLFISH PRODUCTS"},{"foodGroupCount":315,"upperFoodGroup":"BEVERAGES"},{"foodGroupCount":343,"upperFoodGroup":"PORK PRODUCTS"},{"foodGroupCount":346,"upperFoodGroup":"FRUITS AND FRUIT JUICES"},{"foodGroupCount":347,"upperFoodGroup":"SWEETS"},{"foodGroupCount":362,"upperFoodGroup":"BABY FOODS"},{"foodGroupCount":363,"upperFoodGroup":"BREAKFAST CEREALS"},{"foodGroupCount":371,"upperFoodGroup":"FAST FOODS"},{"foodGroupCount":389,"upperFoodGroup":"LEGUMES AND LEGUME PRODUCTS"},{"foodGroupCount":390,"upperFoodGroup":"POULTRY PRODUCTS"},{"foodGroupCount":438,"upperFoodGroup":"LAMB, VEAL, AND GAME PRODUCTS"},{"foodGroupCount":452,"upperFoodGroup":"SOUPS, SAUCES, AND GRAVIES"},{"foodGroupCount":797,"upperFoodGroup":"BAKED PRODUCTS"},{"foodGroupCount":828,"upperFoodGroup":"VEGETABLES AND VEGETABLE PRODUCTS"},{"foodGroupCount":946,"upperFoodGroup":"BEEF PRODUCTS"}]`}, - {name: "QueryGroupBy2", compareField: "foodGroupCount", - query: `SELECT COUNT(1) AS foodGroupCount, ARRAY_CONTAINS(f.tags, {name: 'orange'}) AS containsOrangeTag, f.version BETWEEN 0 AND 2 AS correctVersion FROM Food f GROUP BY ARRAY_CONTAINS(f.tags, {name: 'orange'}), f.version BETWEEN 0 AND 2 WITH collection=%s WITH cross_partition=true`, - expectedResultJson: `[{"foodGroupCount":10,"containsOrangeTag":true,"correctVersion":true},{"foodGroupCount":8608,"containsOrangeTag":false,"correctVersion":true}]`}, - } - _testSelectCustomDataset(t, testName, testCases, db, collname) -} - -func TestSelect_DatasetNutrition_SmallRU(t *testing.T) { - testName := "TestSelect_DatasetNutrition_SmallRU" - client := _newRestClient(t, testName) - dbname := testDb - collname := testTable - _initDataNutritionSmallRU(t, testName, client, dbname, collname) - if result := client.GetPkranges(dbname, collname); result.Error() != nil { - t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) - } else if result.Count != 1 { - t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/GetPkranges", 1, result.Count) - } - db := _openDefaultDb(t, testName, dbname) - _testSelectDatasetNutrition(t, testName, db, collname) -} - -func TestSelect_DatasetNutrition_LargeRU(t *testing.T) { - testName := "TestSelect_DatasetNutrition_LargeRU" - client := _newRestClient(t, testName) - dbname := testDb - collname := testTable - _initDataNutritionLargeRU(t, testName, client, dbname, collname) - if result := client.GetPkranges(dbname, collname); result.Error() != nil { - t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) - } else if result.Count < 2 { - t.Fatalf("%s failed: expected to be larger than %#v but received %#v", testName+"/GetPkranges", 1, result.Count) - } - db := _openDefaultDb(t, testName, dbname) - _testSelectDatasetNutrition(t, testName, db, collname) -} diff --git a/stmt_document_select_test.go b/stmt_document_select_test.go index 80f6374..dda1399 100644 --- a/stmt_document_select_test.go +++ b/stmt_document_select_test.go @@ -404,6 +404,50 @@ func TestStmtSelect_Query_DatasetFamilies_LargeRU(t *testing.T) { _testSelectDatasetFamilies(t, testName, db, collname) } +func _testSelectDatasetNutrition(t *testing.T, testName string, db *sql.DB, collname string) { + var testCases = []customQueryTestCase{ + // ref: https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/group-by + {name: "Count", query: `SELECT COUNT(1) AS foodGroupCount FROM Food f WITH collection=%s WITH cross_partition=true`, expectedResultJson: `[{"foodGroupCount": 8618}]`}, + {name: "QueryGroupBy1", compareField: "foodGroupCount", + query: "SELECT COUNT(1) AS foodGroupCount, UPPER(f.foodGroup) AS upperFoodGroup FROM Food f GROUP BY UPPER(f.foodGroup) WITH collection=%s WITH cross_partition=true", + expectedResultJson: `[{"foodGroupCount":64,"upperFoodGroup":"SPICES AND HERBS"},{"foodGroupCount":108,"upperFoodGroup":"RESTAURANT FOODS"},{"foodGroupCount":113,"upperFoodGroup":"MEALS, ENTREES, AND SIDE DISHES"},{"foodGroupCount":133,"upperFoodGroup":"NUT AND SEED PRODUCTS"},{"foodGroupCount":165,"upperFoodGroup":"AMERICAN INDIAN/ALASKA NATIVE FOODS"},{"foodGroupCount":171,"upperFoodGroup":"SNACKS"},{"foodGroupCount":183,"upperFoodGroup":"CEREAL GRAINS AND PASTA"},{"foodGroupCount":219,"upperFoodGroup":"FATS AND OILS"},{"foodGroupCount":244,"upperFoodGroup":"SAUSAGES AND LUNCHEON MEATS"},{"foodGroupCount":264,"upperFoodGroup":"DAIRY AND EGG PRODUCTS"},{"foodGroupCount":267,"upperFoodGroup":"FINFISH AND SHELLFISH PRODUCTS"},{"foodGroupCount":315,"upperFoodGroup":"BEVERAGES"},{"foodGroupCount":343,"upperFoodGroup":"PORK PRODUCTS"},{"foodGroupCount":346,"upperFoodGroup":"FRUITS AND FRUIT JUICES"},{"foodGroupCount":347,"upperFoodGroup":"SWEETS"},{"foodGroupCount":362,"upperFoodGroup":"BABY FOODS"},{"foodGroupCount":363,"upperFoodGroup":"BREAKFAST CEREALS"},{"foodGroupCount":371,"upperFoodGroup":"FAST FOODS"},{"foodGroupCount":389,"upperFoodGroup":"LEGUMES AND LEGUME PRODUCTS"},{"foodGroupCount":390,"upperFoodGroup":"POULTRY PRODUCTS"},{"foodGroupCount":438,"upperFoodGroup":"LAMB, VEAL, AND GAME PRODUCTS"},{"foodGroupCount":452,"upperFoodGroup":"SOUPS, SAUCES, AND GRAVIES"},{"foodGroupCount":797,"upperFoodGroup":"BAKED PRODUCTS"},{"foodGroupCount":828,"upperFoodGroup":"VEGETABLES AND VEGETABLE PRODUCTS"},{"foodGroupCount":946,"upperFoodGroup":"BEEF PRODUCTS"}]`}, + {name: "QueryGroupBy2", compareField: "foodGroupCount", + query: `SELECT COUNT(1) AS foodGroupCount, ARRAY_CONTAINS(f.tags, {name: 'orange'}) AS containsOrangeTag, f.version BETWEEN 0 AND 2 AS correctVersion FROM Food f GROUP BY ARRAY_CONTAINS(f.tags, {name: 'orange'}), f.version BETWEEN 0 AND 2 WITH collection=%s WITH cross_partition=true`, + expectedResultJson: `[{"foodGroupCount":10,"containsOrangeTag":true,"correctVersion":true},{"foodGroupCount":8608,"containsOrangeTag":false,"correctVersion":true}]`}, + } + _testSelectCustomDataset(t, testName, testCases, db, collname) +} + +func TestStmtSelect_Query_DatasetNutrition_SmallRU(t *testing.T) { + testName := "TestStmtSelect_Query_DatasetNutrition_SmallRU" + client := _newRestClient(t, testName) + dbname := testDb + collname := testTable + _initDataNutritionSmallRU(t, testName, client, dbname, collname) + if result := client.GetPkranges(dbname, collname); result.Error() != nil { + t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) + } else if result.Count != 1 { + t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/GetPkranges", 1, result.Count) + } + db := _openDefaultDb(t, testName, dbname) + _testSelectDatasetNutrition(t, testName, db, collname) +} + +func TestStmtSelect_Query_DatasetNutrition_LargeRU(t *testing.T) { + testName := "TestStmtSelect_Query_DatasetNutrition_LargeRU" + client := _newRestClient(t, testName) + dbname := testDb + collname := testTable + _initDataNutritionLargeRU(t, testName, client, dbname, collname) + if result := client.GetPkranges(dbname, collname); result.Error() != nil { + t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) + } else if result.Count < 2 { + t.Fatalf("%s failed: expected to be larger than %#v but received %#v", testName+"/GetPkranges", 1, result.Count) + } + db := _openDefaultDb(t, testName, dbname) + _testSelectDatasetNutrition(t, testName, db, collname) +} + /*----------------------------------------------------------------------*/ func TestStmtSelect_Query(t *testing.T) { From 5a4664838987f2a0000a843e57fa97b29d52cdf9 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Thu, 8 Jun 2023 13:46:32 +1000 Subject: [PATCH 12/13] update tests --- .github/workflows/gocosmos.yaml | 30 ++---------------------------- gocosmos.go | 2 +- gocosmos_test.go | 30 ------------------------------ stmt_test.go | 24 ++++++++++++++++++++++++ 4 files changed, 27 insertions(+), 59 deletions(-) delete mode 100644 gocosmos_test.go diff --git a/.github/workflows/gocosmos.yaml b/.github/workflows/gocosmos.yaml index b480a51..78fb30b 100644 --- a/.github/workflows/gocosmos.yaml +++ b/.github/workflows/gocosmos.yaml @@ -123,39 +123,13 @@ jobs: netstat -nt $env:COSMOSDB_DRIVER='gocosmos' $env:COSMOSDB_URL='AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==' - go test -cover -coverprofile="coverage_driver_document_query.txt" -v -count 1 -p 1 -run "TestStmtSelect_(Exec|Query)" . + go test -timeout 30m -cover -coverprofile="coverage_driver_document_query.txt" -v -count 1 -p 1 -run "TestStmtSelect_(Exec|Query)" . - name: Codecov uses: codecov/codecov-action@v3 with: flags: driver_document_query name: driver_document_query - testDriverSelect: - name: Test driver SELECT query - runs-on: windows-latest - steps: - - name: Set up Go env - uses: actions/setup-go@v2 - with: - go-version: ^1.13 - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - name: Test - run: | - choco install azure-cosmosdb-emulator - & "C:\Program Files\Azure Cosmos DB Emulator\Microsoft.Azure.Cosmos.Emulator.exe" /DisableRateLimiting /NoUI /NoExplorer - Start-Sleep -s 60 - try { Invoke-RestMethod -Method GET https://127.0.0.1:8081/ } catch {} - netstat -nt - $env:COSMOSDB_DRIVER='gocosmos' - $env:COSMOSDB_URL='AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==' - go test -cover -coverprofile="coverage_driver_select.txt" -v -count 1 -p 1 -run "TestSelect_" . - - name: Codecov - uses: codecov/codecov-action@v3 - with: - flags: driver_select - name: driver_select - testRestClientNonQuery: name: Test RestClient non-query runs-on: windows-latest @@ -201,7 +175,7 @@ jobs: netstat -nt $env:COSMOSDB_DRIVER='gocosmos' $env:COSMOSDB_URL='AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==' - go test -cover -coverprofile="coverage_restclient_query.txt" -v -count 1 -p 1 -run "TestRestClient_QueryDocuments" . + go test -timeout 30m -cover -coverprofile="coverage_restclient_query.txt" -v -count 1 -p 1 -run "TestRestClient_QueryDocuments" . - name: Codecov uses: codecov/codecov-action@v3 with: diff --git a/gocosmos.go b/gocosmos.go index 0fbc91f..ee41912 100644 --- a/gocosmos.go +++ b/gocosmos.go @@ -7,7 +7,7 @@ import ( const ( // Version of package gocosmos. - Version = "0.3.0" + Version = "0.2.1" ) func goTypeToCosmosDbType(typ reflect.Type) string { diff --git a/gocosmos_test.go b/gocosmos_test.go deleted file mode 100644 index e4cb052..0000000 --- a/gocosmos_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package gocosmos - -import ( - "database/sql" - "os" - "strings" - "testing" -) - -func _openDefaultDb(t *testing.T, testName, defaultDb string) *sql.DB { - driver := "gocosmos" - url := strings.ReplaceAll(os.Getenv("COSMOSDB_URL"), `"`, "") - if url == "" { - t.Skipf("%s skipped", testName) - } - if defaultDb != "" { - if strings.Index(url, "DefaultDb=") < 0 { - url += ";DefaultDb=" + defaultDb - } - } - db, err := sql.Open(driver, url) - if err != nil { - t.Fatalf("%s failed: %s", testName+"/sql.Open", err) - } - return db -} - -func _openDb(t *testing.T, testName string) *sql.DB { - return _openDefaultDb(t, testName, "") -} diff --git a/stmt_test.go b/stmt_test.go index 16ba149..609f904 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -2,9 +2,33 @@ package gocosmos import ( "database/sql" + "os" + "strings" "testing" ) +func _openDefaultDb(t *testing.T, testName, defaultDb string) *sql.DB { + driver := "gocosmos" + url := strings.ReplaceAll(os.Getenv("COSMOSDB_URL"), `"`, "") + if url == "" { + t.Skipf("%s skipped", testName) + } + if defaultDb != "" { + if strings.Index(url, "DefaultDb=") < 0 { + url += ";DefaultDb=" + defaultDb + } + } + db, err := sql.Open(driver, url) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/sql.Open", err) + } + return db +} + +func _openDb(t *testing.T, testName string) *sql.DB { + return _openDefaultDb(t, testName, "") +} + func _fetchAllRows(dbRows *sql.Rows) ([]map[string]interface{}, error) { colTypes, err := dbRows.ColumnTypes() if err != nil { From 33de9249cbd01bd634181f272154a2276baa7742 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Thu, 8 Jun 2023 15:08:02 +1000 Subject: [PATCH 13/13] prepare to release v0.2.1 --- data_test.go | 11 ++++++++--- stmt.go | 15 ++++++++++----- stmt_collection_parsing_test.go | 2 +- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/data_test.go b/data_test.go index 80552f8..116b17d 100644 --- a/data_test.go +++ b/data_test.go @@ -28475,7 +28475,7 @@ const _testDataVolcano = ` func _initDataNutrition(t *testing.T, testName string, client *RestClient, db, container string) { dataListNutrition := make([]DocInfo, 0) - dataMapNutrition := make(map[string]DocInfo) + dataMapNutrition := sync.Map{} err := json.Unmarshal([]byte(_testDataNutrition), &dataListNutrition) if err != nil { t.Fatalf("%s failed: %s", testName, err) @@ -28493,7 +28493,7 @@ func _initDataNutrition(t *testing.T, testName string, client *RestClient, db, c defer wg.Done() for doc := range buff { docId := doc["id"].(string) - dataMapNutrition[docId] = doc + dataMapNutrition.Store(docId, doc) if result := client.CreateDocument(DocumentSpec{DbName: db, CollName: container, PartitionKeyValues: []interface{}{docId}, DocumentData: doc}); result.Error() != nil { t.Fatalf("%s failed: (%#v) %s", testName, id, result.Error()) } @@ -28524,7 +28524,12 @@ func _initDataNutrition(t *testing.T, testName string, client *RestClient, db, c fmt.Printf("\t[DEBUG] Dur: %.3f / Rate: %.3f\n", d.Seconds(), r) time.Sleep(1*time.Second + time.Duration(rand.Intn(1234))*time.Millisecond) } - fmt.Printf("\tDataset: %#v / (checksum) Number of records: %#v\n", "Nutrition", len(dataMapNutrition)) + count := 0 + dataMapNutrition.Range(func(_, _ interface{}) bool { + count++ + return true + }) + fmt.Printf("\tDataset: %#v / (checksum) Number of records: %#v\n", "Nutrition", count) } func _initDataNutritionSmallRU(t *testing.T, testName string, client *RestClient, db, container string) { diff --git a/stmt.go b/stmt.go index efac824..fd85955 100644 --- a/stmt.go +++ b/stmt.go @@ -15,7 +15,6 @@ const ( ifNotExists = `(\s+IF\s+NOT\s+EXISTS)?` ifExists = `(\s+IF\s+EXISTS)?` with = `(\s+WITH\s+` + field + `\s*=\s*([\w/\.\*,;:'"-]+)((\s+|\s*,\s+|\s+,\s*)WITH\s+` + field + `\s*=\s*([\w/\.\*,;:'"-]+))*)?` - // with = `((\s+WITH\s+([\w-]+)\s*=\s*([\w/\.,;:'"-]+))*)` ) var ( @@ -217,15 +216,21 @@ type Stmt struct { withOpts map[string]string } -var reWithOpt = regexp.MustCompile(`(?i)WITH\s+([\w-]+)\s*=\s*([\w/\.,;:'"-]+)`) +var reWithOpts = regexp.MustCompile(`(?is)^(\s+|\s*,\s+|\s+,\s*)WITH\s+` + field + `\s*=\s*([\w/\.\*,;:'"-]+)`) // parseWithOpts parses "WITH..." clause and store result in withOpts map. // This function returns no error. Sub-implementations may override this behavior. func (s *Stmt) parseWithOpts(withOptsStr string) error { + withOptsStr = " " + withOptsStr s.withOpts = make(map[string]string) - tokens := reWithOpt.FindAllStringSubmatch(withOptsStr, -1) - for _, token := range tokens { - s.withOpts[strings.TrimSpace(strings.ToUpper(token[1]))] = strings.TrimSpace(token[2]) + for { + matches := reWithOpts.FindStringSubmatch(withOptsStr) + if matches == nil { + break + } + k := strings.TrimSpace(strings.ToUpper(matches[2])) + s.withOpts[k] = strings.TrimSuffix(strings.TrimSpace(matches[3]), ",") + withOptsStr = withOptsStr[len(matches[0]):] } return nil } diff --git a/stmt_collection_parsing_test.go b/stmt_collection_parsing_test.go index d616226..772abab 100644 --- a/stmt_collection_parsing_test.go +++ b/stmt_collection_parsing_test.go @@ -40,7 +40,7 @@ func TestStmtCreateCollection_parse(t *testing.T) { return } if err != nil { - t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + t.Fatalf("%s failed: %s\n%s", testName+"/"+testCase.name, err, testCase.sql) } stmt, ok := s.(*StmtCreateCollection) if !ok {