diff --git a/demo/usability_testing/data_mocker.py b/demo/usability_testing/data_mocker.py
index f873daec9dc..6729ef0b70b 100644
--- a/demo/usability_testing/data_mocker.py
+++ b/demo/usability_testing/data_mocker.py
@@ -6,6 +6,7 @@ from typing import Optional
 import numpy as np
 import pandas as pd
+import dateutil.tz

 # to support save csv, and faster parquet, we don't use faker-cli directly
@@ -146,8 +147,9 @@ def type_converter(sql_type):
     if sql_type in ['varchar', 'string']:
         # TODO(hw): set max length
         return 'pystr', {}
+    # timestamp must be > 0 because tablet insert checks it; use UTC
     if sql_type in ['date', 'timestamp']:
-        return 'iso8601', {}
+        return 'iso8601', {"tzinfo": dateutil.tz.UTC}
     if sql_type in ['float', 'double']:
         return 'pyfloat', ranges[sql_type]
     return 'py' + sql_type, {}
diff --git a/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md b/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md
index a44f699eed3..0113ef730b0 100644
--- a/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md
+++ b/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md
@@ -233,8 +233,8 @@ IndexOption ::=
| ----------- | ------------------------------------------------------------ | ---------------------------------------------------- | ------------------------------------------------------------ |
| `ABSOLUTE` | TTL的值代表过期时间。配置值为时间段如`100m, 12h, 1d, 365d`。最大可以配置的过期时间为`15768000m`(即30年) | 当记录过期时,会被淘汰。 | `INDEX(KEY=col1, TS=std_time, TTL_TYPE=absolute, TTL=100m)`
OpenMLDB将会删除100分钟之前的数据。 | | `LATEST` | TTL的值代表最大存活条数。即同一个索引下面,最大允许存在的数据条数。最大可以配置1000条 | 记录超过最大条数时,会被淘汰。 | `INDEX(KEY=col1, TS=std_time, TTL_TYPE=LATEST, TTL=10)`。OpenMLDB只会保留最近10条记录,删除以前的记录。 | -| `ABSORLAT` | 配置过期时间和最大存活条数。配置值是一个2元组,形如`(100m, 10), (1d, 1)`。最大可以配置`(15768000m, 1000)`。 | 当且仅当记录过期**或**记录超过最大条数时,才会淘汰。 | `INDEX(key=c1, ts=c6, ttl=(120min, 100), ttl_type=absorlat)`。当记录超过100条,**或者**当记录过期时,会被淘汰 | -| `ABSANDLAT` | 配置过期时间和最大存活条数。配置值是一个2元组,形如`(100m, 10), (1d, 1)`。最大可以配置`(15768000m, 1000)`。 | 当记录过期**且**记录超过最大条数时,记录会被淘汰。 | `INDEX(key=c1, ts=c6, ttl=(120min, 100), ttl_type=absandlat)`。当记录超过100条,**而且**记录过期时,会被淘汰 | +| `ABSORLAT` | 配置过期时间和最大存活条数。配置值是一个2元组,形如`(100m, 10), (1d, 1)`。最大可以配置`(15768000m, 1000)`。 | 当且仅当记录过期**或**记录超过最大条数时,才会淘汰。 | `INDEX(key=c1, ts=c6, ttl=(120m, 100), ttl_type=absorlat)`。当记录超过100条,**或者**当记录过期时,会被淘汰 | +| `ABSANDLAT` | 配置过期时间和最大存活条数。配置值是一个2元组,形如`(100m, 10), (1d, 1)`。最大可以配置`(15768000m, 1000)`。 | 当记录过期**且**记录超过最大条数时,记录会被淘汰。 | `INDEX(key=c1, ts=c6, ttl=(120m, 100), ttl_type=absandlat)`。当记录超过100条,**而且**记录过期时,会被淘汰 | ```{note} 最大过期时间和最大存活条数的限制,是出于性能考虑。如果你一定要配置更大的TTL值,请使用UpdateTTL来增大(可无视max限制),或者调整nameserver配置`absolute_ttl_max`和`latest_ttl_max`,重启生效。 diff --git a/docs/zh/openmldb_sql/dml/INSERT_STATEMENT.md b/docs/zh/openmldb_sql/dml/INSERT_STATEMENT.md index 6ecf98390a3..4799e557577 100644 --- a/docs/zh/openmldb_sql/dml/INSERT_STATEMENT.md +++ b/docs/zh/openmldb_sql/dml/INSERT_STATEMENT.md @@ -5,7 +5,7 @@ OpenMLDB 支持一次插入单行或多行数据。 ## syntax ``` -INSERT INFO tbl_name (column_list) VALUES (value_list) [, value_list ...] +INSERT [[OR] IGNORE] INTO tbl_name (column_list) VALUES (value_list) [, value_list ...] column_list: col_name [, col_name] ... @@ -16,6 +16,7 @@ value_list: **说明** - `INSERT` 只能用在在线模式 +- 默认`INSERT`不会去重,`INSERT OR IGNORE` 则可以忽略已存在于表中的数据,可以反复重试。 ## Examples diff --git a/docs/zh/openmldb_sql/dml/LOAD_DATA_STATEMENT.md b/docs/zh/openmldb_sql/dml/LOAD_DATA_STATEMENT.md index b3c7ffc55bf..d2c456b2913 100644 --- a/docs/zh/openmldb_sql/dml/LOAD_DATA_STATEMENT.md +++ b/docs/zh/openmldb_sql/dml/LOAD_DATA_STATEMENT.md @@ -58,6 +58,7 @@ FilePathPattern | load_mode | String | cluster | `load_mode='local'`仅支持从csv本地文件导入在线存储, 它通过本地客户端同步插入数据;
`load_mode='cluster'`仅支持集群版, 通过spark插入数据,支持同步或异步模式 |
| thread | Integer | 1 | 仅在本地文件导入时生效,即`load_mode='local'`或者单机版,表示本地插入数据的线程数。 最大值为`50`。 |
| writer_type | String | single | 集群版在线导入中插入数据的writer类型。可选值为`single`和`batch`,默认为`single`。`single`表示数据即读即写,节省内存。`batch`则是将整个rdd分区读完,确认数据类型有效性后,再写入集群,需要更多内存。在部分情况下,`batch`模式有利于筛选未写入的数据,方便重试这部分数据。 |
+| put_if_absent | Boolean | false | 在源数据无重复行、也不与表中已有数据重复时,可以使用此选项避免插入重复数据,特别是job失败后可以重试。等价于使用`INSERT OR IGNORE`。更多详情见下文。 |

```{note}
在集群版中,`LOAD DATA INFILE`语句会根据当前执行模式(execute_mode)决定将数据导入到在线或离线存储。单机版中没有存储区别,只会导入到在线存储中,同时也不支持`deep_copy`选项。
@@ -73,6 +74,7 @@ FilePathPattern
所以,请尽量使用绝对路径。单机测试中,本地文件用`file://`开头;生产环境中,推荐使用hdfs等文件系统。
```
+
 ## SQL语句模版

 ```sql
@@ -158,3 +160,12 @@
 null,null
 第二行两列都是两个双引号。
 - cluster模式默认quote为`"`,所以这一行是两个空字符串。
 - local模式默认quote为`\0`,所以这一行两列都是两个双引号。local模式quote可以配置为`"`,但escape规则是`""`为单个`"`,和Spark不一致,具体见[issue3015](https://github.com/4paradigm/OpenMLDB/issues/3015)。
+
+## PutIfAbsent说明
+
+PutIfAbsent是一个特殊的选项,它可以避免插入重复数据,仅需一个配置,操作简单,特别适合load data job失败后重试,等价于使用`INSERT OR IGNORE`。如果你想要导入的数据中本身存在重复,那么通过PutIfAbsent导入,会导致部分数据丢失。如果你需要保留重复数据,不应使用此选项,建议通过其他方式去重后再导入。
+
+PutIfAbsent需要去重这一额外开销,所以它的性能与去重的复杂度有关:
+
+- 表中只存在ts索引,且同一key+ts的数据量少于10k时(为了精确去重,在同一个key+ts下会逐行对比整行数据),PutIfAbsent的性能表现不会很差,通常导入时间在普通导入时间的2倍以内。
+- 表中如果存在time索引(ts列为空),或者ts索引同一key+ts的数据量大于100k时,PutIfAbsent的性能会很差,导入时间可能超过普通导入时间的10倍,无法正常使用。这样的数据条件下,更建议进行去重后再导入。
diff --git a/docs/zh/quickstart/beginner_must_read.md b/docs/zh/quickstart/beginner_must_read.md
index 60522283942..ad403a6b423 100644
--- a/docs/zh/quickstart/beginner_must_read.md
+++ b/docs/zh/quickstart/beginner_must_read.md
@@ -69,6 +69,16 @@ OpenMLDB是在线离线存储计算分离的,所以,你需要明确自己导
 关于如何设计你的数据流入流出,可参考[实时决策系统中 OpenMLDB 的常见架构整合方式](../tutorial/app_arch.md)。

+### 在线表
+
+在线表的数据存在内存中,同时也会使用硬盘进行备份恢复。在线表的数据,可以通过`select count(*) from t1`来检查条数,或者使用`show table status`来查看表状态(可能有一定延迟,可以稍等再查)。
+
+在线表可以有多个索引,通过`desc`命令可以查看。写入一条数据时,每个索引中都会写入一条,区别是各个索引的分类排序不同。但由于索引还有TTL淘汰机制,各个索引的数据量可能不一致。`select count(*) from t1`和`show table status`的结果是第一个索引的数据量,并不代表其他索引的数据量。SQL查询会使用哪一个索引,是由SQL Engine选择的最优索引,可以通过SQL物理计划来查看。
+
+建表时,可以指定索引,也可以不指定;不指定时,会默认创建一个索引。默认索引无ts列(以当前时间作为排序列,我们称为time索引),将永不淘汰数据,可以以它为标准检查数据量是否准确。但这样的索引会占用太多内存,目前也不可以删除第一条索引(计划未来支持),可以通过NS Client修改TTL淘汰数据,减少它的内存占用。
+
+time索引(无ts的索引)还会影响PutIfAbsent导入。如果你的数据导入可能中途失败,又无其他方法进行删除或去重,想要使用PutIfAbsent来进行导入重试时,请参考[PutIfAbsent说明](../openmldb_sql/dml/LOAD_DATA_STATEMENT.md#putifabsent说明)对自己的数据进行评估,避免PutIfAbsent效率太差。
+
 ## 源数据

 ### LOAD DATA
diff --git a/hybridse/include/node/node_manager.h b/hybridse/include/node/node_manager.h
index 6949faf6f88..914feaca435 100644
--- a/hybridse/include/node/node_manager.h
+++ b/hybridse/include/node/node_manager.h
@@ -166,7 +166,7 @@ class NodeManager {
     SqlNode *MakeInsertTableNode(const std::string &db_name, const std::string &table_name,
                                  const ExprListNode *column_names,
-                                 const ExprListNode *values);
+                                 const ExprListNode *values, InsertStmt::InsertMode insert_mode);

     CreateStmt *MakeCreateTableNode(bool op_if_not_exist, const std::string &db_name, const std::string &table_name,
diff --git a/hybridse/include/node/sql_node.h b/hybridse/include/node/sql_node.h
index 8d641ad8283..801b9063aa4 100644
--- a/hybridse/include/node/sql_node.h
+++ b/hybridse/include/node/sql_node.h
@@ -1865,19 +1865,35 @@ class ColumnDefNode : public SqlNode {
 class InsertStmt : public SqlNode {
  public:
+    // ref zetasql ASTInsertStatement
+    enum InsertMode {
+        DEFAULT_MODE,  // plain INSERT
+        REPLACE,       // INSERT OR REPLACE
+        UPDATE,        // INSERT OR UPDATE
+        IGNORE         // INSERT OR IGNORE
+    };
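+    // Note: the modes mirror zetasql ASTInsertStatement::InsertMode one-to-one, so the
+    // AST converter can static_cast between the two enums; only DEFAULT_MODE and IGNORE
+    // are accepted downstream (see ConvertInsertStatement in ast_node_converter.cc).
+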
    InsertStmt(const std::string &db_name, const std::string &table_name, const std::vector<std::string> &columns,
-               const std::vector<ExprNode *> &values)
+               const std::vector<ExprNode *> &values,
+               InsertMode insert_mode)
         : SqlNode(kInsertStmt, 0, 0),
           db_name_(db_name),
           table_name_(table_name),
           columns_(columns),
           values_(values),
-          is_all_(columns.empty()) {}
+          is_all_(columns.empty()),
+          insert_mode_(insert_mode) {}

-    InsertStmt(const std::string &db_name, const std::string &table_name, const std::vector<ExprNode *> &values)
-        : SqlNode(kInsertStmt, 0, 0), db_name_(db_name), table_name_(table_name), values_(values), is_all_(true) {}
+    InsertStmt(const std::string &db_name, const std::string &table_name, const std::vector<ExprNode *> &values,
+               InsertMode insert_mode)
+        : SqlNode(kInsertStmt, 0, 0),
+          db_name_(db_name),
+          table_name_(table_name),
+          values_(values),
+          is_all_(true),
+          insert_mode_(insert_mode) {}

     void Print(std::ostream &output, const std::string &org_tab) const;

     const std::string db_name_;
@@ -1885,6 +1901,7 @@ class InsertStmt : public SqlNode {
     const std::vector<std::string> columns_;
     const std::vector<ExprNode *> values_;
     const bool is_all_;
+    const InsertMode insert_mode_;
 };

 class StorageModeNode : public SqlNode {
diff --git a/hybridse/src/node/node_manager.cc b/hybridse/src/node/node_manager.cc
index 86d51249e19..192a5214f4a 100644
--- a/hybridse/src/node/node_manager.cc
+++ b/hybridse/src/node/node_manager.cc
@@ -792,9 +792,10 @@ AllNode *NodeManager::MakeAllNode(const std::string &relation_name, const std::s
 }

 SqlNode *NodeManager::MakeInsertTableNode(const std::string &db_name, const std::string &table_name,
-                                          const ExprListNode *columns_expr, const ExprListNode *values) {
+                                          const ExprListNode *columns_expr, const ExprListNode *values,
+                                          InsertStmt::InsertMode insert_mode) {
     if (nullptr == columns_expr) {
-        InsertStmt *node_ptr = new InsertStmt(db_name, table_name, values->children_);
+        InsertStmt *node_ptr = new InsertStmt(db_name, table_name, values->children_, insert_mode);
         return RegisterNode(node_ptr);
     } else {
         std::vector<std::string> column_names;
@@ -811,7 +812,7 @@ SqlNode *NodeManager::MakeInsertTableNode(const std::string &db_name, const std:
             }
         }
     }
-    InsertStmt *node_ptr = new InsertStmt(db_name, table_name, column_names, values->children_);
+    InsertStmt *node_ptr = new InsertStmt(db_name, table_name, column_names, values->children_, insert_mode);
     return RegisterNode(node_ptr);
   }
 }
diff --git a/hybridse/src/node/sql_node_test.cc b/hybridse/src/node/sql_node_test.cc
index e2938656dcc..f047a9cd737 100644
--- a/hybridse/src/node/sql_node_test.cc
+++ b/hybridse/src/node/sql_node_test.cc
@@ -308,7 +308,8 @@ TEST_F(SqlNodeTest, MakeInsertNodeTest) {
     value_expr_list->PushBack(value4);
     ExprListNode *insert_values = node_manager_->MakeExprList();
     insert_values->PushBack(value_expr_list);
-    SqlNode *node_ptr = node_manager_->MakeInsertTableNode("", "t1", column_expr_list, insert_values);
+    SqlNode *node_ptr = node_manager_->MakeInsertTableNode("", "t1", column_expr_list, insert_values,
+                                                           InsertStmt::InsertMode::DEFAULT_MODE);
     ASSERT_EQ(kInsertStmt, node_ptr->GetType());
     InsertStmt *insert_stmt = dynamic_cast<InsertStmt *>(node_ptr);
diff --git a/hybridse/src/planv2/ast_node_converter.cc b/hybridse/src/planv2/ast_node_converter.cc
index 5d9eb939113..53580b27874 100644
--- a/hybridse/src/planv2/ast_node_converter.cc
+++ b/hybridse/src/planv2/ast_node_converter.cc
@@ -1962,8 +1962,9 @@ base::Status ConvertInsertStatement(const zetasql::ASTInsertStatement* root, nod
     }
     CHECK_TRUE(nullptr == root->query(), common::kSqlAstError, "Un-support insert statement with query");

-
CHECK_TRUE(zetasql::ASTInsertStatement::InsertMode::DEFAULT_MODE == root->insert_mode(), common::kSqlAstError, - "Un-support insert mode ", root->GetSQLForInsertMode()); + CHECK_TRUE(zetasql::ASTInsertStatement::InsertMode::DEFAULT_MODE == root->insert_mode() || + zetasql::ASTInsertStatement::InsertMode::IGNORE == root->insert_mode(), + common::kSqlAstError, "Un-support insert mode ", root->GetSQLForInsertMode()); CHECK_TRUE(nullptr == root->returning(), common::kSqlAstError, "Un-support insert statement with return clause currently", root->GetSQLForInsertMode()); CHECK_TRUE(nullptr == root->assert_rows_modified(), common::kSqlAstError, @@ -2000,8 +2001,8 @@ base::Status ConvertInsertStatement(const zetasql::ASTInsertStatement* root, nod if (names.size() == 2) { db_name = names[0]; } - *output = - dynamic_cast(node_manager->MakeInsertTableNode(db_name, table_name, column_list, rows)); + *output = dynamic_cast(node_manager->MakeInsertTableNode( + db_name, table_name, column_list, rows, static_cast(root->insert_mode()))); return base::Status::OK(); } base::Status ConvertDropStatement(const zetasql::ASTDropStatement* root, node::NodeManager* node_manager, diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/LoadDataPlan.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/LoadDataPlan.scala index a04b46ab650..ec9946b839a 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/LoadDataPlan.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/LoadDataPlan.scala @@ -55,16 +55,19 @@ object LoadDataPlan { loadDataSql) // write - logger.info("write data to storage {}, writer[mode {}], is deep? {}", storage, mode, deepCopy.toString) + logger.info("write data to storage {}, writer mode {}, is deep {}", storage, mode, deepCopy.toString) if (storage == "online") { // Import online data require(deepCopy && mode == "append", "import to online storage, can't do soft copy, and mode must be append") val writeType = extra.get("writer_type").get + val putIfAbsent = extra.get("put_if_absent").get.toBoolean + logger.info(s"online write type ${writeType}, put if absent ${putIfAbsent}") val writeOptions = Map( "db" -> db, "table" -> table, "zkCluster" -> ctx.getConf.openmldbZkCluster, "zkPath" -> ctx.getConf.openmldbZkRootPath, - "writerType" -> writeType + "writerType" -> writeType, + "putIfAbsent" -> putIfAbsent.toString ) df.write.options(writeOptions).format("openmldb").mode(mode).save() } else { // Import offline data diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala index 8bf6897d82f..6f3e5b78d40 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala @@ -247,16 +247,17 @@ object HybridseUtil { } // extra options for some special case - // only for PhysicalLoadDataNode var extraOptions: mutable.Map[String, String] = mutable.Map() + // only for PhysicalLoadDataNode extraOptions += ("deep_copy" -> parseOption(getOptionFromNode(node, "deep_copy"), "true", getBoolOrDefault)) - - // only for select into, "" means N/A - extraOptions += ("coalesce" -> parseOption(getOptionFromNode(node, "coalesce"), "0", getIntOrDefault)) - extraOptions += ("sql" -> parseOption(getOptionFromNode(node, "sql"), "", getStringOrDefault)) 
extraOptions += ("writer_type") -> parseOption(getOptionFromNode(node, "writer_type"), "single", getStringOrDefault) + extraOptions += ("sql" -> parseOption(getOptionFromNode(node, "sql"), "", getStringOrDefault)) + extraOptions += ("put_if_absent" -> parseOption(getOptionFromNode(node, "put_if_absent"), "false", + getBoolOrDefault)) + // only for select into, "" means N/A + extraOptions += ("coalesce" -> parseOption(getOptionFromNode(node, "coalesce"), "0", getIntOrDefault)) extraOptions += ("create_if_not_exists" -> parseOption(getOptionFromNode(node, "create_if_not_exists"), "true", getBoolOrDefault)) diff --git a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/jdbc/SQLConnection.java b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/jdbc/SQLConnection.java index 5383eaf246d..8259682755d 100644 --- a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/jdbc/SQLConnection.java +++ b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/jdbc/SQLConnection.java @@ -82,7 +82,8 @@ public java.sql.Statement createStatement() throws SQLException { @Override public java.sql.PreparedStatement prepareStatement(String sql) throws SQLException { String lower = sql.toLowerCase(); - if (lower.startsWith("insert into")) { + // insert, insert or xxx + if (lower.startsWith("insert ")) { return client.getInsertPreparedStmt(this.defaultDatabase, sql); } else if (lower.startsWith("select")) { return client.getPreparedStatement(this.defaultDatabase, sql); diff --git a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/SdkOption.java b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/SdkOption.java index eca5289bf32..66d0d83bef9 100644 --- a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/SdkOption.java +++ b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/SdkOption.java @@ -17,13 +17,14 @@ package com._4paradigm.openmldb.sdk; import lombok.Data; +import java.io.Serializable; import com._4paradigm.openmldb.BasicRouterOptions; import com._4paradigm.openmldb.SQLRouterOptions; import com._4paradigm.openmldb.StandaloneOptions; @Data -public class SdkOption { +public class SdkOption implements Serializable { // TODO(hw): set isClusterMode automatically private boolean isClusterMode = true; // options for cluster mode diff --git a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementImpl.java b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementImpl.java index 6acefe8acff..ecc39b467c1 100644 --- a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementImpl.java +++ b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementImpl.java @@ -319,7 +319,7 @@ public boolean execute() throws SQLException { // actually only one row boolean ok = router.ExecuteInsert(cache.getDatabase(), cache.getName(), cache.getTid(), cache.getPartitionNum(), - dimensions.array(), dimensions.capacity(), value.array(), value.capacity(), status); + dimensions.array(), dimensions.capacity(), value.array(), value.capacity(), cache.isPutIfAbsent(), status); // cleanup rows even if insert failed // we can't execute() again without set new row, so we must clean up here clearParameters(); @@ -381,7 +381,7 @@ public int[] executeBatch() throws SQLException { boolean ok = router.ExecuteInsert(cache.getDatabase(), cache.getName(), cache.getTid(), cache.getPartitionNum(), pair.getKey().array(), pair.getKey().capacity(), - 
pair.getValue().array(), pair.getValue().capacity(), status); + pair.getValue().array(), pair.getValue().capacity(), cache.isPutIfAbsent(), status); if (!ok) { // TODO(hw): may lost log, e.g. openmldb-batch online import in yarn mode? logger.warn(status.ToString()); diff --git a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementMeta.java b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementMeta.java index 448438e9d31..cf2bd05cb58 100644 --- a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementMeta.java +++ b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementMeta.java @@ -31,6 +31,7 @@ public class InsertPreparedStatementMeta { private Set indexPos = new HashSet<>(); private Map> indexMap = new HashMap<>(); private Map defaultIndexValue = new HashMap<>(); + private boolean putIfAbsent; public InsertPreparedStatementMeta(String sql, NS.TableInfo tableInfo, SQLInsertRow insertRow) { this.sql = sql; @@ -51,6 +52,7 @@ public InsertPreparedStatementMeta(String sql, NS.TableInfo tableInfo, SQLInsert VectorUint32 idxArray = insertRow.GetHoleIdx(); buildHoleIdx(idxArray); idxArray.delete(); + putIfAbsent = insertRow.IsPutIfAbsent(); } private void buildIndex(NS.TableInfo tableInfo) { @@ -215,4 +217,8 @@ Map> getIndexMap() { Map getDefaultIndexValue() { return defaultIndexValue; } + + public boolean isPutIfAbsent() { + return putIfAbsent; + } } diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbConfig.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbConfig.java new file mode 100644 index 00000000000..7c0981d0a6c --- /dev/null +++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbConfig.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com._4paradigm.openmldb.spark;
+
+import com._4paradigm.openmldb.sdk.SdkOption;
+
+import java.io.Serializable;
+
+import org.sparkproject.guava.base.Preconditions;
+
+// Must be serializable
+public class OpenmldbConfig implements Serializable {
+    public final static String DB = "db";
+    public final static String TABLE = "table";
+    public final static String ZK_CLUSTER = "zkCluster";
+    public final static String ZK_PATH = "zkPath";
+
+    /* read&write */
+    private String dbName;
+    private String tableName;
+    private SdkOption option = null;
+
+    /* write */
+    // single: insert when read one row
+    // batch: insert when commit(after read a whole partition)
+    private String writerType = "single";
+    private int insertMemoryUsageLimit = 0;
+    private boolean putIfAbsent = false;
+
+    public OpenmldbConfig() {
+    }
+
+    public void setDB(String dbName) {
+        Preconditions.checkArgument(dbName != null && !dbName.isEmpty(), "db name must not be empty");
+        this.dbName = dbName;
+    }
+
+    public String getDB() {
+        return this.dbName;
+    }
+
+    public void setTable(String tableName) {
+        Preconditions.checkArgument(tableName != null && !tableName.isEmpty(), "table name must not be empty");
+        this.tableName = tableName;
+    }
+
+    public String getTable() {
+        return this.tableName;
+    }
+
+    public void setSdkOption(SdkOption option) {
+        this.option = option;
+    }
+
+    public SdkOption getSdkOption() {
+        return this.option;
+    }
+
+    public void setWriterType(String writerType) {
+        Preconditions.checkArgument(writerType.equals("single") || writerType.equals("batch"),
+                "writerType must be 'single' or 'batch'");
+        this.writerType = writerType;
+    }
+
+    public void setInsertMemoryUsageLimit(int limit) {
+        Preconditions.checkArgument(limit >= 0, "insert_memory_usage_limit must be >= 0");
+        this.insertMemoryUsageLimit = limit;
+    }
+
+    public void setPutIfAbsent(Boolean putIfAbsent) {
+        this.putIfAbsent = putIfAbsent;
+    }
+
+    public boolean isBatchWriter() {
+        return this.writerType.equals("batch");
+    }
+
+    public boolean putIfAbsent() {
+        return this.putIfAbsent;
+    }
+
+    public int getInsertMemoryUsageLimit() {
+        return this.insertMemoryUsageLimit;
+    }
+
+}
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbSource.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbSource.java
index 7e626f623ea..9dfe78f0197 100644
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbSource.java
+++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbSource.java
@@ -18,7 +18,6 @@ package com._4paradigm.openmldb.spark;
 import com._4paradigm.openmldb.sdk.SdkOption;
-import com.google.common.base.Preconditions;
 import org.apache.spark.sql.connector.catalog.Table;
 import org.apache.spark.sql.connector.catalog.TableProvider;
 import org.apache.spark.sql.connector.expressions.Transform;
@@ -29,32 +28,20 @@ import java.util.Map;

 public class OpenmldbSource implements TableProvider, DataSourceRegister {
-    private final String DB = "db";
-    private final String TABLE = "table";
-    private final String ZK_CLUSTER = "zkCluster";
-    private final String ZK_PATH = "zkPath";
-
-    private String dbName;
-    private String tableName;
-    private SdkOption option = null;
-    // single: insert when read one row
-    // batch: insert when commit(after read a whole partition)
-    private String writerType = "single";
-    private int insertMemoryUsageLimit = 0;
+    private OpenmldbConfig config = new OpenmldbConfig();

     @Override
     public StructType 
inferSchema(CaseInsensitiveStringMap options) { - Preconditions.checkNotNull(dbName = options.get(DB)); - Preconditions.checkNotNull(tableName = options.get(TABLE)); + config.setDB(options.get(OpenmldbConfig.DB)); + config.setTable(options.get(OpenmldbConfig.TABLE)); - String zkCluster = options.get(ZK_CLUSTER); - String zkPath = options.get(ZK_PATH); - Preconditions.checkNotNull(zkCluster); - Preconditions.checkNotNull(zkPath); - option = new SdkOption(); - option.setZkCluster(zkCluster); - option.setZkPath(zkPath); + SdkOption option = new SdkOption(); + option.setZkCluster(options.get(OpenmldbConfig.ZK_CLUSTER)); + option.setZkPath(options.get(OpenmldbConfig.ZK_PATH)); option.setLight(true); + config.setSdkOption(option); + String timeout = options.get("sessionTimeout"); if (timeout != null) { option.setSessionTimeout(Integer.parseInt(timeout)); @@ -69,18 +56,21 @@ public StructType inferSchema(CaseInsensitiveStringMap options) { } if (options.containsKey("writerType")) { - writerType = options.get("writerType"); + config.setWriterType(options.get("writerType")); + } + if (options.containsKey("putIfAbsent")) { + config.setPutIfAbsent(Boolean.valueOf(options.get("putIfAbsent"))); } if (options.containsKey("insert_memory_usage_limit")) { - insertMemoryUsageLimit = Integer.parseInt(options.get("insert_memory_usage_limit")); + config.setInsertMemoryUsageLimit(Integer.parseInt(options.get("insert_memory_usage_limit"))); } return null; } @Override public Table getTable(StructType schema, Transform[] partitioning, Map properties) { - return new OpenmldbTable(dbName, tableName, option, writerType, insertMemoryUsageLimit); + return new OpenmldbTable(config); } @Override diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbTable.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbTable.java index 481a9cc1f4c..e5cbcfe40ca 100644 --- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbTable.java +++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbTable.java @@ -22,10 +22,8 @@ import com._4paradigm.openmldb.sdk.SqlException; import com._4paradigm.openmldb.sdk.SqlExecutor; import com._4paradigm.openmldb.sdk.impl.SqlClusterExecutor; -import com._4paradigm.openmldb.spark.read.OpenmldbReadConfig; import com._4paradigm.openmldb.spark.read.OpenmldbScanBuilder; import com._4paradigm.openmldb.spark.write.OpenmldbWriteBuilder; -import com._4paradigm.openmldb.spark.write.OpenmldbWriteConfig; import org.apache.spark.sql.connector.catalog.SupportsRead; import org.apache.spark.sql.connector.catalog.SupportsWrite; import org.apache.spark.sql.connector.catalog.TableCapability; @@ -45,40 +43,32 @@ import java.util.Set; public class OpenmldbTable implements SupportsWrite, SupportsRead { - private final String dbName; - private final String tableName; - private final SdkOption option; - private final String writerType; - private final int insertMemoryUsageLimit; - private SqlExecutor executor = null; + private OpenmldbConfig config; + private SqlExecutor executor; private Set capabilities; - public OpenmldbTable(String dbName, String tableName, SdkOption option, String writerType, int insertMemoryUsageLimit) { - this.dbName = dbName; - this.tableName = tableName; - this.option = option; - this.writerType = writerType; - this.insertMemoryUsageLimit = insertMemoryUsageLimit; + public OpenmldbTable(OpenmldbConfig config) { + this.config = config; try { - this.executor = new 
SqlClusterExecutor(option); + this.executor = new SqlClusterExecutor(config.getSdkOption()); // no need to check table exists, schema() will check it later } catch (SqlException e) { e.printStackTrace(); + throw new RuntimeException("conn openmldb failed", e); } // TODO: cache schema & delete executor? } @Override public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { - OpenmldbWriteConfig config = new OpenmldbWriteConfig(dbName, tableName, option, writerType, insertMemoryUsageLimit); return new OpenmldbWriteBuilder(config, info); } @Override public String name() { // TODO(hw): db? - return tableName; + return config.getTable(); } public static DataType sdkTypeToSparkType(int sqlType) { @@ -109,7 +99,7 @@ public static DataType sdkTypeToSparkType(int sqlType) { @Override public StructType schema() { try { - Schema schema = executor.getTableSchema(dbName, tableName); + Schema schema = executor.getTableSchema(config.getDB(), config.getTable()); List schemaList = schema.getColumnList(); StructField[] fields = new StructField[schemaList.size()]; for (int i = 0; i < schemaList.size(); i++) { @@ -136,7 +126,6 @@ public Set capabilities() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap caseInsensitiveStringMap) { - OpenmldbReadConfig config = new OpenmldbReadConfig(dbName, tableName, option); return new OpenmldbScanBuilder(config); } } diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReaderFactory.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReaderFactory.java index d5e435fc247..929d30b728e 100644 --- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReaderFactory.java +++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReaderFactory.java @@ -17,15 +17,16 @@ package com._4paradigm.openmldb.spark.read; +import com._4paradigm.openmldb.spark.OpenmldbConfig; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.connector.read.PartitionReaderFactory; public class OpenmldbPartitionReaderFactory implements PartitionReaderFactory { - private final OpenmldbReadConfig config; + private final OpenmldbConfig config; - public OpenmldbPartitionReaderFactory(OpenmldbReadConfig config) { + public OpenmldbPartitionReaderFactory(OpenmldbConfig config) { this.config = config; } diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbReadConfig.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbReadConfig.java deleted file mode 100644 index 91489888ba9..00000000000 --- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbReadConfig.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com._4paradigm.openmldb.spark.read; - -import com._4paradigm.openmldb.sdk.SdkOption; -import java.io.Serializable; - -// Must serializable -public class OpenmldbReadConfig implements Serializable { - public final String dbName, tableName, zkCluster, zkPath; - - public OpenmldbReadConfig(String dbName, String tableName, SdkOption option) { - this.dbName = dbName; - this.tableName = tableName; - this.zkCluster = option.getZkCluster(); - this.zkPath = option.getZkPath(); - // TODO(hw): other configs in SdkOption - } -} diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScan.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScan.java index fb7adb46b8e..4eeac9a6013 100644 --- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScan.java +++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScan.java @@ -17,6 +17,7 @@ package com._4paradigm.openmldb.spark.read; +import com._4paradigm.openmldb.spark.OpenmldbConfig; import org.apache.spark.sql.connector.read.Batch; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.PartitionReaderFactory; @@ -24,9 +25,9 @@ import org.apache.spark.sql.types.StructType; public class OpenmldbScan implements Scan, Batch { - private final OpenmldbReadConfig config; + private final OpenmldbConfig config; - public OpenmldbScan(OpenmldbReadConfig config) { + public OpenmldbScan(OpenmldbConfig config) { this.config = config; } diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScanBuilder.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScanBuilder.java index 2b500a6592e..de59a811f46 100644 --- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScanBuilder.java +++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScanBuilder.java @@ -17,13 +17,14 @@ package com._4paradigm.openmldb.spark.read; +import com._4paradigm.openmldb.spark.OpenmldbConfig; import org.apache.spark.sql.connector.read.Scan; import org.apache.spark.sql.connector.read.ScanBuilder; public class OpenmldbScanBuilder implements ScanBuilder { - private final OpenmldbReadConfig config; + private final OpenmldbConfig config; - public OpenmldbScanBuilder(OpenmldbReadConfig config) { + public OpenmldbScanBuilder(OpenmldbConfig config) { this.config = config; } diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbBatchWrite.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbBatchWrite.java index ca90a07d63a..d19fd9f6aeb 100644 --- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbBatchWrite.java +++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbBatchWrite.java @@ -17,6 +17,7 @@ package com._4paradigm.openmldb.spark.write; +import 
com._4paradigm.openmldb.spark.OpenmldbConfig; import org.apache.spark.sql.connector.write.BatchWrite; import org.apache.spark.sql.connector.write.DataWriterFactory; import org.apache.spark.sql.connector.write.LogicalWriteInfo; @@ -24,10 +25,10 @@ import org.apache.spark.sql.connector.write.WriterCommitMessage; public class OpenmldbBatchWrite implements BatchWrite { - private final OpenmldbWriteConfig config; + private final OpenmldbConfig config; private final LogicalWriteInfo info; - public OpenmldbBatchWrite(OpenmldbWriteConfig config, LogicalWriteInfo info) { + public OpenmldbBatchWrite(OpenmldbConfig config, LogicalWriteInfo info) { this.config = config; this.info = info; } diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataSingleWriter.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataSingleWriter.java index 843fb9a8da7..cc5f0150cc3 100644 --- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataSingleWriter.java +++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataSingleWriter.java @@ -17,8 +17,9 @@ package com._4paradigm.openmldb.spark.write; +import com._4paradigm.openmldb.spark.OpenmldbConfig; + import com._4paradigm.openmldb.sdk.Schema; -import com._4paradigm.openmldb.sdk.SdkOption; import com._4paradigm.openmldb.sdk.SqlException; import com._4paradigm.openmldb.sdk.impl.SqlClusterExecutor; import com.google.common.base.Preconditions; @@ -27,32 +28,26 @@ import org.apache.spark.sql.connector.write.WriterCommitMessage; import java.io.IOException; -import java.sql.Date; import java.sql.PreparedStatement; import java.sql.ResultSetMetaData; import java.sql.SQLException; -import java.sql.Timestamp; -import java.sql.Types; public class OpenmldbDataSingleWriter implements DataWriter { private final int partitionId; private final long taskId; private PreparedStatement preparedStatement = null; - public OpenmldbDataSingleWriter(OpenmldbWriteConfig config, int partitionId, long taskId) { + public OpenmldbDataSingleWriter(OpenmldbConfig config, int partitionId, long taskId) { try { - SdkOption option = new SdkOption(); - option.setZkCluster(config.zkCluster); - option.setZkPath(config.zkPath); - option.setLight(true); - SqlClusterExecutor executor = new SqlClusterExecutor(option); - String dbName = config.dbName; - String tableName = config.tableName; - executor.executeSQL(dbName, "SET @@insert_memory_usage_limit=" + config.insertMemoryUsageLimit); + SqlClusterExecutor executor = new SqlClusterExecutor(config.getSdkOption()); + String dbName = config.getDB(); + String tableName = config.getTable(); + executor.executeSQL(dbName, "SET @@insert_memory_usage_limit=" + config.getInsertMemoryUsageLimit()); Schema schema = executor.getTableSchema(dbName, tableName); // create insert placeholder - StringBuilder insert = new StringBuilder("insert into " + tableName + " values(?"); + String insert_part = config.putIfAbsent()? 
"insert or ignore into " : "insert into "; + StringBuilder insert = new StringBuilder(insert_part + tableName + " values(?"); for (int i = 1; i < schema.getColumnList().size(); i++) { insert.append(",?"); } @@ -60,6 +55,7 @@ public OpenmldbDataSingleWriter(OpenmldbWriteConfig config, int partitionId, lon preparedStatement = executor.getInsertPreparedStmt(dbName, insert.toString()); } catch (SQLException | SqlException e) { e.printStackTrace(); + throw new RuntimeException("create openmldb writer failed", e); } this.partitionId = partitionId; @@ -73,7 +69,12 @@ public void write(InternalRow record) throws IOException { ResultSetMetaData metaData = preparedStatement.getMetaData(); Preconditions.checkState(record.numFields() == metaData.getColumnCount()); OpenmldbDataWriter.addRow(record, preparedStatement); - preparedStatement.execute(); + // check return for put result + // you can cache failed rows and throw exception when commit/close, + // but it still may interrupt other writers(pending or slow writers) + if(!preparedStatement.execute()) { + throw new IOException("execute failed"); + } } catch (Exception e) { throw new IOException("write row to openmldb failed on " + record, e); } @@ -81,24 +82,13 @@ public void write(InternalRow record) throws IOException { @Override public WriterCommitMessage commit() throws IOException { - try { - preparedStatement.close(); - } catch (SQLException e) { - e.printStackTrace(); - throw new IOException("commit error", e); - } - // TODO(hw): need to return new WriterCommitMessageImpl(partitionId, taskId); ? + // no transaction, no commit return null; } @Override public void abort() throws IOException { - try { - preparedStatement.close(); - } catch (SQLException e) { - e.printStackTrace(); - throw new IOException("abort error", e); - } + // no transaction, no abort } @Override diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriter.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriter.java index e9ba0e30c5a..65bc2e5a457 100644 --- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriter.java +++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriter.java @@ -17,6 +17,8 @@ package com._4paradigm.openmldb.spark.write; +import com._4paradigm.openmldb.spark.OpenmldbConfig; + import com._4paradigm.openmldb.sdk.Schema; import com._4paradigm.openmldb.sdk.SdkOption; import com._4paradigm.openmldb.sdk.SqlException; @@ -39,20 +41,17 @@ public class OpenmldbDataWriter implements DataWriter { private final long taskId; private PreparedStatement preparedStatement = null; - public OpenmldbDataWriter(OpenmldbWriteConfig config, int partitionId, long taskId) { + public OpenmldbDataWriter(OpenmldbConfig config, int partitionId, long taskId) { try { - SdkOption option = new SdkOption(); - option.setZkCluster(config.zkCluster); - option.setZkPath(config.zkPath); - option.setLight(true); - SqlClusterExecutor executor = new SqlClusterExecutor(option); - String dbName = config.dbName; - String tableName = config.tableName; - executor.executeSQL(dbName, "SET @@insert_memory_usage_limit=" + config.insertMemoryUsageLimit); + SqlClusterExecutor executor = new SqlClusterExecutor(config.getSdkOption()); + String dbName = config.getDB(); + String tableName = config.getTable(); + executor.executeSQL(dbName, "SET @@insert_memory_usage_limit=" + config.getInsertMemoryUsageLimit()); Schema 
schema = executor.getTableSchema(dbName, tableName); // create insert placeholder - StringBuilder insert = new StringBuilder("insert into " + tableName + " values(?"); + String insert_part = config.putIfAbsent()? "insert or ignore into " : "insert into "; + StringBuilder insert = new StringBuilder(insert_part + tableName + " values(?"); for (int i = 1; i < schema.getColumnList().size(); i++) { insert.append(",?"); } @@ -60,6 +59,7 @@ public OpenmldbDataWriter(OpenmldbWriteConfig config, int partitionId, long task preparedStatement = executor.getInsertPreparedStmt(dbName, insert.toString()); } catch (SQLException | SqlException e) { e.printStackTrace(); + throw new RuntimeException("create openmldb data writer failed", e); } this.partitionId = partitionId; @@ -147,12 +147,7 @@ public WriterCommitMessage commit() throws IOException { @Override public void abort() throws IOException { - try { - preparedStatement.close(); - } catch (SQLException e) { - e.printStackTrace(); - throw new IOException("abort error", e); - } + // no transaction, no abort } @Override diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriterFactory.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriterFactory.java index 96e78979b2f..12cefb3928b 100644 --- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriterFactory.java +++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriterFactory.java @@ -17,20 +17,21 @@ package com._4paradigm.openmldb.spark.write; +import com._4paradigm.openmldb.spark.OpenmldbConfig; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.write.DataWriter; import org.apache.spark.sql.connector.write.DataWriterFactory; public class OpenmldbDataWriterFactory implements DataWriterFactory { - private final OpenmldbWriteConfig config; + private final OpenmldbConfig config; - public OpenmldbDataWriterFactory(OpenmldbWriteConfig config) { + public OpenmldbDataWriterFactory(OpenmldbConfig config) { this.config = config; } @Override public DataWriter createWriter(int partitionId, long taskId) { - if (!config.writerType.equals("batch")) { + if (!config.isBatchWriter()) { return new OpenmldbDataSingleWriter(config, partitionId, taskId); } return new OpenmldbDataWriter(config, partitionId, taskId); diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteBuilder.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteBuilder.java index a3c905b15c1..ccd588df0c4 100644 --- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteBuilder.java +++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteBuilder.java @@ -17,15 +17,16 @@ package com._4paradigm.openmldb.spark.write; +import com._4paradigm.openmldb.spark.OpenmldbConfig; import org.apache.spark.sql.connector.write.BatchWrite; import org.apache.spark.sql.connector.write.LogicalWriteInfo; import org.apache.spark.sql.connector.write.WriteBuilder; public class OpenmldbWriteBuilder implements WriteBuilder { - private final OpenmldbWriteConfig config; + private final OpenmldbConfig config; private final LogicalWriteInfo info; - public OpenmldbWriteBuilder(OpenmldbWriteConfig config, LogicalWriteInfo info) { + public OpenmldbWriteBuilder(OpenmldbConfig config, 
LogicalWriteInfo info) { this.config = config; this.info = info; } diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteConfig.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteConfig.java deleted file mode 100644 index 80007b14ae5..00000000000 --- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteConfig.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com._4paradigm.openmldb.spark.write; - -import com._4paradigm.openmldb.sdk.SdkOption; - -import java.io.Serializable; - -// Must serializable -public class OpenmldbWriteConfig implements Serializable { - public final String dbName, tableName, zkCluster, zkPath, writerType; - public final int insertMemoryUsageLimit; - - public OpenmldbWriteConfig(String dbName, String tableName, SdkOption option, String writerType, int insertMemoryUsageLimit) { - this.dbName = dbName; - this.tableName = tableName; - this.zkCluster = option.getZkCluster(); - this.zkPath = option.getZkPath(); - this.writerType = writerType; - this.insertMemoryUsageLimit = insertMemoryUsageLimit; - // TODO(hw): other configs in SdkOption - } -} diff --git a/java/openmldb-spark-connector/src/main/scala/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReader.scala b/java/openmldb-spark-connector/src/main/scala/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReader.scala index d8eeb89e7ab..86921d0f4d5 100644 --- a/java/openmldb-spark-connector/src/main/scala/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReader.scala +++ b/java/openmldb-spark-connector/src/main/scala/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReader.scala @@ -1,5 +1,6 @@ package com._4paradigm.openmldb.spark.read +import com._4paradigm.openmldb.spark.OpenmldbConfig import com._4paradigm.openmldb.sdk.{Schema, SdkOption} import com._4paradigm.openmldb.sdk.impl.SqlClusterExecutor import org.apache.spark.sql.catalyst.InternalRow @@ -8,15 +9,10 @@ import org.apache.spark.unsafe.types.UTF8String import java.sql.Types -class OpenmldbPartitionReader(config: OpenmldbReadConfig) extends PartitionReader[InternalRow] { - - val option = new SdkOption - option.setZkCluster(config.zkCluster) - option.setZkPath(config.zkPath) - option.setLight(true) - val executor = new SqlClusterExecutor(option) - val dbName: String = config.dbName - val tableName: String = config.tableName +class OpenmldbPartitionReader(config: OpenmldbConfig) extends PartitionReader[InternalRow] { + val executor = new SqlClusterExecutor(config.getSdkOption) + val dbName: String = config.getDB + val tableName: String = config.getTable val schema: Schema = executor.getTableSchema(dbName, tableName) 
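  // the connector reads the online table, so pin this executor session to online execute mode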
executor.executeSQL(dbName, "SET @@execute_mode='online'") diff --git a/src/client/tablet_client.cc b/src/client/tablet_client.cc index 54b2a8c9cec..f445cc1791c 100644 --- a/src/client/tablet_client.cc +++ b/src/client/tablet_client.cc @@ -203,19 +203,21 @@ bool TabletClient::UpdateTableMetaForAddField(uint32_t tid, const std::vector>& dimensions, - int memory_usage_limit) { + int memory_usage_limit, bool put_if_absent) { + ::google::protobuf::RepeatedPtrField<::openmldb::api::Dimension> pb_dimensions; for (size_t i = 0; i < dimensions.size(); i++) { ::openmldb::api::Dimension* d = pb_dimensions.Add(); d->set_key(dimensions[i].first); d->set_idx(dimensions[i].second); } - return Put(tid, pid, time, base::Slice(value), &pb_dimensions, memory_usage_limit); + + return Put(tid, pid, time, base::Slice(value), &pb_dimensions, memory_usage_limit, put_if_absent); } base::Status TabletClient::Put(uint32_t tid, uint32_t pid, uint64_t time, const base::Slice& value, ::google::protobuf::RepeatedPtrField<::openmldb::api::Dimension>* dimensions, - int memory_usage_limit) { + int memory_usage_limit, bool put_if_absent) { ::openmldb::api::PutRequest request; if (memory_usage_limit < 0 || memory_usage_limit > 100) { return {base::ReturnCode::kError, absl::StrCat("invalid memory_usage_limit ", memory_usage_limit)}; @@ -227,6 +229,7 @@ base::Status TabletClient::Put(uint32_t tid, uint32_t pid, uint64_t time, const request.set_tid(tid); request.set_pid(pid); request.mutable_dimensions()->Swap(dimensions); + request.set_put_if_absent(put_if_absent); ::openmldb::api::PutResponse response; auto st = client_.SendRequestSt(&::openmldb::api::TabletServer_Stub::Put, &request, &response, FLAGS_request_timeout_ms, 1); diff --git a/src/client/tablet_client.h b/src/client/tablet_client.h index b4866b77618..19579f90c5c 100644 --- a/src/client/tablet_client.h +++ b/src/client/tablet_client.h @@ -76,20 +76,18 @@ class TabletClient : public Client { base::Status Put(uint32_t tid, uint32_t pid, uint64_t time, const std::string& value, const std::vector>& dimensions, - int memory_usage_limit = 0); + int memory_usage_limit = 0, bool put_if_absent = false); base::Status Put(uint32_t tid, uint32_t pid, uint64_t time, const base::Slice& value, ::google::protobuf::RepeatedPtrField<::openmldb::api::Dimension>* dimensions, - int memory_usage_limit = 0); + int memory_usage_limit = 0, bool put_if_absent = false); bool Get(uint32_t tid, uint32_t pid, const std::string& pk, uint64_t time, std::string& value, // NOLINT uint64_t& ts, // NOLINT std::string& msg); // NOLINT bool Get(uint32_t tid, uint32_t pid, const std::string& pk, uint64_t time, const std::string& idx_name, - std::string& value, // NOLINT - uint64_t& ts, // NOLINT - std::string& msg); // NOLINT + std::string& value, uint64_t& ts, std::string& msg); // NOLINT bool Delete(uint32_t tid, uint32_t pid, const std::string& pk, const std::string& idx_name, std::string& msg); // NOLINT diff --git a/src/cmd/openmldb.cc b/src/cmd/openmldb.cc index 8d0d9b692f5..53da31ad634 100644 --- a/src/cmd/openmldb.cc +++ b/src/cmd/openmldb.cc @@ -438,7 +438,7 @@ std::shared_ptr<::openmldb::client::TabletClient> GetTabletClient(const ::openml void HandleNSClientSetTTL(const std::vector& parts, ::openmldb::client::NsClient* client) { if (parts.size() < 4) { - std::cout << "bad setttl format, eg settl t1 absolute 10" << std::endl; + std::cout << "bad setttl format, eg settl t1 absolute 10 [index0]" << std::endl; return; } std::string index_name; @@ -1307,14 +1307,14 @@ void HandleNSGet(const 
std::vector& parts, ::openmldb::client::NsCl if (parts.size() < 4) { std::cout << "get format error. eg: get table_name key ts | get " "table_name key idx_name ts | get table_name=xxx key=xxx " - "index_name=xxx ts=xxx ts_name=xxx " + "index_name=xxx ts=xxx" << std::endl; return; } std::map parameter_map; if (!GetParameterMap("table_name", parts, "=", parameter_map)) { std::cout << "get format error. eg: get table_name=xxx key=xxx " - "index_name=xxx ts=xxx ts_name=xxx " + "index_name=xxx ts=xxx" << std::endl; return; } @@ -1382,7 +1382,7 @@ void HandleNSGet(const std::vector& parts, ::openmldb::client::NsCl return; } ::openmldb::codec::SDKCodec codec(tables[0]); - bool no_schema = tables[0].column_desc_size() == 0 && tables[0].column_desc_size() == 0; + bool no_schema = tables[0].column_desc_size() == 0; if (no_schema) { std::string value; uint64_t ts = 0; @@ -2459,7 +2459,7 @@ void HandleNSClientHelp(const std::vector& parts, ::openmldb::clien printf("ex:man create\n"); } else if (parts[1] == "setttl") { printf("desc: set table ttl \n"); - printf("usage: setttl table_name ttl_type ttl [ts_name]\n"); + printf("usage: setttl table_name ttl_type ttl [index_name], abs ttl unit is minute\n"); printf("ex: setttl t1 absolute 10\n"); printf("ex: setttl t2 latest 5\n"); printf("ex: setttl t3 latest 5 ts1\n"); diff --git a/src/proto/tablet.proto b/src/proto/tablet.proto index 2c7a038960b..a18714b2ae1 100755 --- a/src/proto/tablet.proto +++ b/src/proto/tablet.proto @@ -195,6 +195,7 @@ message PutRequest { repeated TSDimension ts_dimensions = 7 [deprecated = true]; optional uint32 format_version = 8 [default = 0, deprecated = true]; optional uint32 memory_limit = 9; + optional bool put_if_absent = 10 [default = false]; } message PutResponse { diff --git a/src/sdk/node_adapter_test.cc b/src/sdk/node_adapter_test.cc index e09758b07cd..70c35ff7d9c 100644 --- a/src/sdk/node_adapter_test.cc +++ b/src/sdk/node_adapter_test.cc @@ -64,7 +64,7 @@ void CheckTablePartition(const ::openmldb::nameserver::TableInfo& table_info, if (table_partition.partition_meta(pos).is_leader()) { ASSERT_EQ(table_partition.partition_meta(pos).endpoint(), leader); } else { - ASSERT_EQ(follower.count(table_partition.partition_meta(pos).endpoint()), 1); + ASSERT_EQ(follower.count(table_partition.partition_meta(pos).endpoint()), (std::size_t)1); } } } diff --git a/src/sdk/sql_cache.h b/src/sdk/sql_cache.h index a326437c10f..1fe0b346fa2 100644 --- a/src/sdk/sql_cache.h +++ b/src/sdk/sql_cache.h @@ -54,26 +54,28 @@ class InsertSQLCache : public SQLCache { InsertSQLCache(const std::shared_ptr<::openmldb::nameserver::TableInfo>& table_info, const std::shared_ptr<::hybridse::sdk::Schema>& column_schema, DefaultValueMap default_map, - uint32_t str_length, std::vector hole_idx_arr) + uint32_t str_length, std::vector hole_idx_arr, bool put_if_absent) : SQLCache(table_info->db(), table_info->tid(), table_info->name()), table_info_(table_info), column_schema_(column_schema), default_map_(std::move(default_map)), str_length_(str_length), - hole_idx_arr_(std::move(hole_idx_arr)) {} + hole_idx_arr_(std::move(hole_idx_arr)), + put_if_absent_(put_if_absent) {} std::shared_ptr<::openmldb::nameserver::TableInfo> GetTableInfo() { return table_info_; } std::shared_ptr<::hybridse::sdk::Schema> GetSchema() const { return column_schema_; } uint32_t GetStrLength() const { return str_length_; } const DefaultValueMap& GetDefaultValue() const { return default_map_; } const std::vector& GetHoleIdxArr() const { return hole_idx_arr_; } - + const bool 
IsPutIfAbsent() const { return put_if_absent_; } private: std::shared_ptr<::openmldb::nameserver::TableInfo> table_info_; std::shared_ptr<::hybridse::sdk::Schema> column_schema_; const DefaultValueMap default_map_; const uint32_t str_length_; const std::vector hole_idx_arr_; + const bool put_if_absent_; }; class RouterSQLCache : public SQLCache { diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index 7c5a98814b9..bdad16cfc8c 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -455,39 +455,40 @@ std::shared_ptr SQLClusterRouter::GetInsertRow(const std::string& *status = {}; return std::make_shared(insert_cache->GetTableInfo(), insert_cache->GetSchema(), insert_cache->GetDefaultValue(), insert_cache->GetStrLength(), - insert_cache->GetHoleIdxArr()); + insert_cache->GetHoleIdxArr(), insert_cache->IsPutIfAbsent()); } } std::shared_ptr<::openmldb::nameserver::TableInfo> table_info; DefaultValueMap default_map; uint32_t str_length = 0; std::vector stmt_column_idx_arr; - if (!GetInsertInfo(db, sql, status, &table_info, &default_map, &str_length, &stmt_column_idx_arr)) { + bool put_if_absent = false; + if (!GetInsertInfo(db, sql, status, &table_info, &default_map, &str_length, &stmt_column_idx_arr, &put_if_absent)) { SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "get insert information failed"); return {}; } auto schema = openmldb::schema::SchemaAdapter::ConvertSchema(table_info->column_desc()); - auto insert_cache = - std::make_shared(table_info, schema, default_map, str_length, - SQLInsertRow::GetHoleIdxArr(default_map, stmt_column_idx_arr, schema)); + auto insert_cache = std::make_shared( + table_info, schema, default_map, str_length, + SQLInsertRow::GetHoleIdxArr(default_map, stmt_column_idx_arr, schema), put_if_absent); SetCache(db, sql, hybridse::vm::kBatchMode, insert_cache); *status = {}; return std::make_shared(insert_cache->GetTableInfo(), insert_cache->GetSchema(), insert_cache->GetDefaultValue(), insert_cache->GetStrLength(), - insert_cache->GetHoleIdxArr()); + insert_cache->GetHoleIdxArr(), insert_cache->IsPutIfAbsent()); } bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::string& sql, ::hybridse::sdk::Status* status, std::shared_ptr<::openmldb::nameserver::TableInfo>* table_info, std::vector* default_maps, - std::vector* str_lengths) { + std::vector* str_lengths, bool* put_if_absent) { RET_FALSE_IF_NULL_AND_WARN(status, "output status is nullptr"); // TODO(hw): return status? 
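    // put_if_absent is an output parameter too: it is derived from the parsed InsertStmt
    // below (INSERT OR IGNORE maps to true), so it must be non-null like the other outputs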
RET_FALSE_IF_NULL_AND_WARN(table_info, "output table_info is nullptr"); RET_FALSE_IF_NULL_AND_WARN(default_maps, "output default_maps is nullptr"); RET_FALSE_IF_NULL_AND_WARN(str_lengths, "output str_lengths is nullptr"); - + RET_FALSE_IF_NULL_AND_WARN(put_if_absent, "output put_if_absent is nullptr"); ::hybridse::node::NodeManager nm; ::hybridse::plan::PlanNodeList plans; bool ok = GetSQLPlan(sql, &nm, &plans); @@ -506,6 +507,7 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s SET_STATUS_AND_WARN(status, StatusCode::kPlanError, "insert stmt is null"); return false; } + *put_if_absent = insert_stmt->insert_mode_ == ::hybridse::node::InsertStmt::IGNORE; std::string db_name; if (!insert_stmt->db_name_.empty()) { db_name = insert_stmt->db_name_; @@ -576,7 +578,7 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s bool SQLClusterRouter::GetInsertInfo(const std::string& db, const std::string& sql, ::hybridse::sdk::Status* status, std::shared_ptr<::openmldb::nameserver::TableInfo>* table_info, DefaultValueMap* default_map, uint32_t* str_length, - std::vector* stmt_column_idx_in_table) { + std::vector* stmt_column_idx_in_table, bool* put_if_absent) { RET_FALSE_IF_NULL_AND_WARN(status, "output status is nullptr"); RET_FALSE_IF_NULL_AND_WARN(table_info, "output table_info is nullptr"); RET_FALSE_IF_NULL_AND_WARN(default_map, "output default_map is nullptr"); @@ -636,6 +638,7 @@ bool SQLClusterRouter::GetInsertInfo(const std::string& db, const std::string& s SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "get default value map of " + sql + " failed"); return false; } + *put_if_absent = insert_stmt->insert_mode_ == ::hybridse::node::InsertStmt::IGNORE; return true; } @@ -771,23 +774,24 @@ std::shared_ptr SQLClusterRouter::GetInsertRows(const std::string status->SetOK(); return std::make_shared(insert_cache->GetTableInfo(), insert_cache->GetSchema(), insert_cache->GetDefaultValue(), insert_cache->GetStrLength(), - insert_cache->GetHoleIdxArr()); + insert_cache->GetHoleIdxArr(), insert_cache->IsPutIfAbsent()); } } std::shared_ptr<::openmldb::nameserver::TableInfo> table_info; DefaultValueMap default_map; uint32_t str_length = 0; std::vector stmt_column_idx_arr; - if (!GetInsertInfo(db, sql, status, &table_info, &default_map, &str_length, &stmt_column_idx_arr)) { + bool put_if_absent = false; + if (!GetInsertInfo(db, sql, status, &table_info, &default_map, &str_length, &stmt_column_idx_arr, &put_if_absent)) { return {}; } auto col_schema = openmldb::schema::SchemaAdapter::ConvertSchema(table_info->column_desc()); - insert_cache = - std::make_shared(table_info, col_schema, default_map, str_length, - SQLInsertRow::GetHoleIdxArr(default_map, stmt_column_idx_arr, col_schema)); + insert_cache = std::make_shared( + table_info, col_schema, default_map, str_length, + SQLInsertRow::GetHoleIdxArr(default_map, stmt_column_idx_arr, col_schema), put_if_absent); SetCache(db, sql, hybridse::vm::kBatchMode, insert_cache); return std::make_shared(table_info, insert_cache->GetSchema(), default_map, str_length, - insert_cache->GetHoleIdxArr()); + insert_cache->GetHoleIdxArr(), insert_cache->IsPutIfAbsent()); } bool SQLClusterRouter::ExecuteDDL(const std::string& db, const std::string& sql, hybridse::sdk::Status* status) { @@ -1303,7 +1307,8 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s std::shared_ptr<::openmldb::nameserver::TableInfo> table_info; std::vector default_maps; std::vector str_lengths; - if 
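On the plan side, `insert_mode_ == IGNORE` is the only thing that flips `put_if_absent`, matching the documented `INSERT [[OR] IGNORE]` syntax. A caller-side sketch of why that matters (database, table, and values are made up; `ExecuteInsert(db, sql, status)` is the plain-SQL SQLRouter entry point, and `status.msg` is its message field):

    // Hypothetical retry loop: with OR IGNORE, re-sending the same rows is
    // safe because rows that already landed are skipped, not duplicated.
    hybridse::sdk::Status status;
    std::string sql = "INSERT OR IGNORE INTO t1 VALUES (1, 'a', 1635247427000);";
    for (int attempt = 0; attempt < 3; attempt++) {
        if (router->ExecuteInsert("demo_db", sql, &status)) {
            break;  // success; duplicates were ignored rather than re-inserted
        }
        LOG(WARNING) << "insert attempt " << attempt << " failed: " << status.msg;
    }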
@@ -1303,7 +1307,8 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s
     std::shared_ptr<::openmldb::nameserver::TableInfo> table_info;
     std::vector<DefaultValueMap> default_maps;
     std::vector<uint32_t> str_lengths;
-    if (!GetMultiRowInsertInfo(db, sql, status, &table_info, &default_maps, &str_lengths)) {
+    bool put_if_absent;
+    if (!GetMultiRowInsertInfo(db, sql, status, &table_info, &default_maps, &str_lengths, &put_if_absent)) {
         CODE_PREPEND_AND_WARN(status, StatusCode::kCmdError, "Fail to get insert info");
         return false;
     }
@@ -1318,7 +1323,7 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s
     }
     std::vector fails;
     for (size_t i = 0; i < default_maps.size(); i++) {
-        auto row = std::make_shared<SQLInsertRow>(table_info, schema, default_maps[i], str_lengths[i]);
+        auto row = std::make_shared<SQLInsertRow>(table_info, schema, default_maps[i], str_lengths[i], put_if_absent);
         if (!row) {
            LOG(WARNING) << "fail to parse row[" << i << "]";
             fails.push_back(i);
@@ -1368,7 +1373,7 @@ bool SQLClusterRouter::PutRow(uint32_t tid, const std::shared_ptr&
     DLOG(INFO) << "put data to endpoint " << client->GetEndpoint() << " with dimensions size " << kv.second.size();
     auto ret = client->Put(tid, pid, cur_ts, row->GetRow(), kv.second,
-                           insert_memory_usage_limit_.load(std::memory_order_relaxed));
+                           insert_memory_usage_limit_.load(std::memory_order_relaxed), row->IsPutIfAbsent());
     if (!ret.OK()) {
         if (RevertPut(row->GetTableInfo(), pid, dimensions, cur_ts, base::Slice(row->GetRow()), tablets).IsOK()) {
@@ -1448,8 +1453,8 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s
 }

 bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& name, int tid, int partition_num,
-                                     hybridse::sdk::ByteArrayPtr dimension, int dimension_len,
-                                     hybridse::sdk::ByteArrayPtr value, int len, hybridse::sdk::Status* status) {
+                                     hybridse::sdk::ByteArrayPtr dimension, int dimension_len,
+                                     hybridse::sdk::ByteArrayPtr value, int len, bool put_if_absent, hybridse::sdk::Status* status) {
     RET_FALSE_IF_NULL_AND_WARN(status, "output status is nullptr");
     if (dimension == nullptr || dimension_len <= 0 || value == nullptr || len <= 0 || partition_num <= 0) {
         *status = {StatusCode::kCmdError, "invalid parameter"};
@@ -1491,7 +1496,7 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& n
     DLOG(INFO) << "put data to endpoint " << client->GetEndpoint() << " with dimensions size " << kv.second.size();
     auto ret = client->Put(tid, pid, cur_ts, row_value, &kv.second,
-                           insert_memory_usage_limit_.load(std::memory_order_relaxed));
+                           insert_memory_usage_limit_.load(std::memory_order_relaxed), put_if_absent);
     if (!ret.OK()) {
         SET_STATUS_AND_WARN(status, StatusCode::kCmdError,
                             "INSERT failed, tid " + std::to_string(tid) +
diff --git a/src/sdk/sql_cluster_router.h b/src/sdk/sql_cluster_router.h
index 502ad07dab6..0b9f6cca272 100644
--- a/src/sdk/sql_cluster_router.h
+++ b/src/sdk/sql_cluster_router.h
@@ -87,7 +87,7 @@ class SQLClusterRouter : public SQLRouter {

     bool ExecuteInsert(const std::string& db, const std::string& name, int tid, int partition_num,
                        hybridse::sdk::ByteArrayPtr dimension, int dimension_len,
-                       hybridse::sdk::ByteArrayPtr value, int len, hybridse::sdk::Status* status) override;
+                       hybridse::sdk::ByteArrayPtr value, int len, bool put_if_absent, hybridse::sdk::Status* status) override;

     bool ExecuteDelete(std::shared_ptr row, hybridse::sdk::Status* status) override;

@@ -316,10 +316,11 @@ class SQLClusterRouter : public SQLRouter {
     bool GetInsertInfo(const std::string& db, const std::string& sql, ::hybridse::sdk::Status* status,
                        std::shared_ptr<::openmldb::nameserver::TableInfo>* table_info, DefaultValueMap* default_map,
-                       uint32_t* str_length, std::vector* stmt_column_idx_in_table);
+                       uint32_t* str_length, std::vector* stmt_column_idx_in_table, bool* put_if_absent);

     bool GetMultiRowInsertInfo(const std::string& db, const std::string& sql, ::hybridse::sdk::Status* status,
                                std::shared_ptr<::openmldb::nameserver::TableInfo>* table_info,
-                               std::vector<DefaultValueMap>* default_maps, std::vector<uint32_t>* str_lengths);
+                               std::vector<DefaultValueMap>* default_maps, std::vector<uint32_t>* str_lengths,
+                               bool* put_if_absent);

     DefaultValueMap GetDefaultMap(const std::shared_ptr<::openmldb::nameserver::TableInfo>& table_info,
                                   const std::map& column_map, ::hybridse::node::ExprListNode* row,
diff --git a/src/sdk/sql_insert_row.cc b/src/sdk/sql_insert_row.cc
index a2d44571be2..492bb80e49b 100644
--- a/src/sdk/sql_insert_row.cc
+++ b/src/sdk/sql_insert_row.cc
@@ -29,33 +29,35 @@ namespace sdk {

 SQLInsertRows::SQLInsertRows(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info,
                              std::shared_ptr<hybridse::sdk::Schema> schema, DefaultValueMap default_map,
-                             uint32_t default_str_length, const std::vector& hole_idx_arr)
+                             uint32_t default_str_length, const std::vector& hole_idx_arr, bool put_if_absent)
     : table_info_(std::move(table_info)),
       schema_(std::move(schema)),
       default_map_(std::move(default_map)),
       default_str_length_(default_str_length),
-      hole_idx_arr_(hole_idx_arr) {}
+      hole_idx_arr_(hole_idx_arr),
+      put_if_absent_(put_if_absent) {}

 std::shared_ptr<SQLInsertRow> SQLInsertRows::NewRow() {
     if (!rows_.empty() && !rows_.back()->IsComplete()) {
         return {};
     }
-    std::shared_ptr<SQLInsertRow> row =
-        std::make_shared<SQLInsertRow>(table_info_, schema_, default_map_, default_str_length_, hole_idx_arr_);
+    std::shared_ptr<SQLInsertRow> row = std::make_shared<SQLInsertRow>(
+        table_info_, schema_, default_map_, default_str_length_, hole_idx_arr_, put_if_absent_);
     rows_.push_back(row);
     return row;
 }

 SQLInsertRow::SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info,
                            std::shared_ptr<hybridse::sdk::Schema> schema, DefaultValueMap default_map,
-                           uint32_t default_string_length)
+                           uint32_t default_string_length, bool put_if_absent)
     : table_info_(table_info),
       schema_(std::move(schema)),
       default_map_(std::move(default_map)),
       default_string_length_(default_string_length),
       rb_(table_info->column_desc()),
       val_(),
-      str_size_(0) {
+      str_size_(0),
+      put_if_absent_(put_if_absent) {
     std::map column_name_map;
     for (int idx = 0; idx < table_info_->column_desc_size(); idx++) {
         column_name_map.emplace(table_info_->column_desc(idx).name(), idx);
@@ -81,8 +83,9 @@ SQLInsertRow::SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> ta

 SQLInsertRow::SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info,
                            std::shared_ptr<hybridse::sdk::Schema> schema, DefaultValueMap default_map,
-                           uint32_t default_str_length, std::vector hole_idx_arr)
-    : SQLInsertRow(std::move(table_info), std::move(schema), std::move(default_map), default_str_length) {
+                           uint32_t default_str_length, std::vector hole_idx_arr, bool put_if_absent)
+    : SQLInsertRow(std::move(table_info), std::move(schema), std::move(default_map), default_str_length,
+                   put_if_absent) {
     hole_idx_arr_ = std::move(hole_idx_arr);
 }
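`SQLInsertRows::NewRow` forwards the cached flag to every `SQLInsertRow` it mints, so a prepared multi-row insert keeps one set of semantics across all rows. A usage sketch (database, table, and placeholder values are illustrative; `GetInsertRows(db, sql, status)` is the existing prepared-insert entry point):

    hybridse::sdk::Status status;
    auto rows = router->GetInsertRows("demo_db", "insert or ignore into t1 values (?, ?);", &status);
    if (rows) {
        auto row = rows->NewRow();
        // row->IsPutIfAbsent() is already true: the insert mode was parsed once,
        // cached, and inherited here, not re-derived per row.
    }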
diff --git a/src/sdk/sql_insert_row.h b/src/sdk/sql_insert_row.h
index ded1c824e19..af18891587f 100644
--- a/src/sdk/sql_insert_row.h
+++ b/src/sdk/sql_insert_row.h
@@ -103,12 +103,13 @@ class DefaultValueContainer {

 class SQLInsertRow {
 public:
+    // for raw insert sql (no holes)
     SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info,
                  std::shared_ptr<hybridse::sdk::Schema> schema, DefaultValueMap default_map,
-                 uint32_t default_str_length);
+                 uint32_t default_str_length, bool put_if_absent);
     SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info,
                  std::shared_ptr<hybridse::sdk::Schema> schema, DefaultValueMap default_map,
-                 uint32_t default_str_length, std::vector hole_idx_arr);
+                 uint32_t default_str_length, std::vector hole_idx_arr, bool put_if_absent);
     ~SQLInsertRow() = default;
     bool Init(int str_length);
     bool AppendBool(bool val);
@@ -155,6 +156,10 @@ class SQLInsertRow {
         return *table_info_;
     }

+    bool IsPutIfAbsent() const {
+        return put_if_absent_;
+    }
+
 private:
     bool MakeDefault();
     void PackDimension(const std::string& val);
@@ -175,13 +180,14 @@ class SQLInsertRow {
     ::openmldb::codec::RowBuilder rb_;
     std::string val_;
     uint32_t str_size_;
+    bool put_if_absent_;
 };

 class SQLInsertRows {
 public:
     SQLInsertRows(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info,
                   std::shared_ptr<hybridse::sdk::Schema> schema, DefaultValueMap default_map, uint32_t str_size,
-                  const std::vector& hole_idx_arr);
+                  const std::vector& hole_idx_arr, bool put_if_absent);
     ~SQLInsertRows() = default;
     std::shared_ptr<SQLInsertRow> NewRow();
     inline uint32_t GetCnt() { return rows_.size(); }
@@ -200,6 +206,7 @@ class SQLInsertRows {
     DefaultValueMap default_map_;
     uint32_t default_str_length_;
     std::vector hole_idx_arr_;
+    bool put_if_absent_;
     std::vector<std::shared_ptr<SQLInsertRow>> rows_;
 };
diff --git a/src/sdk/sql_router.h b/src/sdk/sql_router.h
index 4317d435f8c..07b2e3b7734 100644
--- a/src/sdk/sql_router.h
+++ b/src/sdk/sql_router.h
@@ -130,7 +130,7 @@ class SQLRouter {

     virtual bool ExecuteInsert(const std::string& db, const std::string& name, int tid, int partition_num,
                                hybridse::sdk::ByteArrayPtr dimension, int dimension_len,
-                               hybridse::sdk::ByteArrayPtr value, int len, hybridse::sdk::Status* status) = 0;
+                               hybridse::sdk::ByteArrayPtr value, int len, bool put_if_absent, hybridse::sdk::Status* status) = 0;

     virtual bool ExecuteDelete(std::shared_ptr row, hybridse::sdk::Status* status) = 0;
diff --git a/src/storage/aggregator.cc b/src/storage/aggregator.cc
index 7814c687be5..4615c87bc20 100644
--- a/src/storage/aggregator.cc
+++ b/src/storage/aggregator.cc
@@ -641,9 +641,9 @@ bool Aggregator::FlushAggrBuffer(const std::string& key, const std::string& filt
     auto dimension = entry.add_dimensions();
     dimension->set_idx(aggr_index_pos_);
     dimension->set_key(key);
-    bool ok = aggr_table_->Put(time, entry.value(), entry.dimensions());
-    if (!ok) {
-        PDLOG(ERROR, "Aggregator put failed");
+    auto st = aggr_table_->Put(time, entry.value(), entry.dimensions());
+    if (!st.ok()) {
+        LOG(ERROR) << "Aggregator put failed: " << st.ToString();
         return false;
     }
     entry.set_pk(key);
diff --git a/src/storage/disk_table.cc b/src/storage/disk_table.cc
index b41c9f8fd3c..af35ab9a170 100644
--- a/src/storage/disk_table.cc
+++ b/src/storage/disk_table.cc
@@ -227,7 +227,8 @@ bool DiskTable::Put(const std::string& pk, uint64_t time, const char* data, uint
     }
 }

-bool DiskTable::Put(uint64_t time, const std::string& value, const Dimensions& dimensions) {
+absl::Status DiskTable::Put(uint64_t time, const std::string& value, const Dimensions& dimensions, bool put_if_absent) {
+    // the disk table updates in place when key-time is the same, so there is no need to handle put_if_absent
     const int8_t* data = reinterpret_cast<const int8_t*>(value.data());
     std::string uncompress_data;
     if (GetCompressType() == openmldb::type::kSnappy) {
@@ -237,15 +238,14 @@ bool DiskTable::Put(uint64_t time, const std::string& value, const Dimensions& d
     uint8_t version = codec::RowView::GetSchemaVersion(data);
     auto decoder = GetVersionDecoder(version);
     if (decoder == nullptr) {
-        PDLOG(WARNING, "invalid schema version %u, tid %u pid %u", version, id_, pid_);
-        return false;
+        return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid schema version ", version));
     }
     rocksdb::WriteBatch batch;
     for (auto it = dimensions.begin(); it != dimensions.end(); ++it) {
         auto index_def = table_index_.GetIndex(it->idx());
         if (!index_def || !index_def->IsReady()) {
-            PDLOG(WARNING, "failed putting key %s to dimension %u in table tid %u pid %u", it->key().c_str(),
-                  it->idx(), id_, pid_);
+            PDLOG(WARNING, "failed putting key %s to dimension %u in table tid %u pid %u", it->key().c_str(), it->idx(),
+                  id_, pid_);
         }
         int32_t inner_pos = table_index_.GetInnerIndexPos(it->idx());
         auto inner_index = table_index_.GetInnerIndex(inner_pos);
@@ -256,12 +256,10 @@ bool DiskTable::Put(uint64_t time, const std::string& value, const Dimensions& d
         if (ts_col->IsAutoGenTs()) {
             ts = time;
         } else if (decoder->GetInteger(data, ts_col->GetId(), ts_col->GetType(), &ts) != 0) {
-            PDLOG(WARNING, "get ts failed. tid %u pid %u", id_, pid_);
-            return false;
+            return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": get ts failed"));
         }
         if (ts < 0) {
-            PDLOG(WARNING, "ts %ld is negative. tid %u pid %u", ts, id_, pid_);
-            return false;
+            return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": ts is negative ", ts));
         }
         if (inner_index->GetIndex().size() > 1) {
             combine_key = CombineKeyTs(it->key(), ts, ts_col->GetId());
@@ -275,10 +273,9 @@ bool DiskTable::Put(uint64_t time, const std::string& value, const Dimensions& d
     auto s = db_->Write(write_opts_, &batch);
     if (s.ok()) {
         offset_.fetch_add(1, std::memory_order_relaxed);
-        return true;
+        return absl::OkStatus();
     } else {
-        DEBUGLOG("Put failed. tid %u pid %u msg %s", id_, pid_, s.ToString().c_str());
-        return false;
+        return absl::InternalError(absl::StrCat(id_, ".", pid_, ": ", s.ToString()));
     }
 }
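The comment on `DiskTable::Put` is doing real work: RocksDB keeps a single value per key, so re-putting the same combined (pk, ts) key overwrites rather than duplicates, and `put_if_absent` can be accepted and ignored. A minimal illustration of that overwrite behavior (toy values and a hypothetical helper, not the table's real encoding):

    #include "rocksdb/db.h"
    #include "rocksdb/write_batch.h"

    void OverwriteDemo(rocksdb::DB* db, rocksdb::ColumnFamilyHandle* cf, const rocksdb::Slice& combine_key) {
        rocksdb::WriteBatch batch;
        batch.Put(cf, combine_key, "v1");
        batch.Put(cf, combine_key, "v2");  // same (pk, ts) key: v2 silently replaces v1
        rocksdb::Status s = db->Write(rocksdb::WriteOptions(), &batch);
        // afterwards a Get on combine_key yields "v2"; a duplicate row never exists
    }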
diff --git a/src/storage/disk_table.h b/src/storage/disk_table.h
index be549d0c2cd..7b471bac45e 100644
--- a/src/storage/disk_table.h
+++ b/src/storage/disk_table.h
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include "base/slice.h"
 #include "base/status.h"
 #include "common/timer.h"
@@ -102,7 +103,7 @@ class AbsoluteTTLCompactionFilter : public rocksdb::CompactionFilter {
             return false;
         }
         uint32_t ts_idx = *((uint32_t*)(key.data() + key.size() - TS_LEN -  // NOLINT
-                            TS_POS_LEN));
+                                        TS_POS_LEN));
         bool has_found = false;
         for (const auto& index : indexs) {
             auto ts_col = index->GetTsColumn();
@@ -110,7 +111,7 @@ class AbsoluteTTLCompactionFilter : public rocksdb::CompactionFilter {
                 return false;
             }
             if (ts_col->GetId() == ts_idx &&
-                    index->GetTTL()->ttl_type == openmldb::storage::TTLType::kAbsoluteTime) {
+                index->GetTTL()->ttl_type == openmldb::storage::TTLType::kAbsoluteTime) {
                 real_ttl = index->GetTTL()->abs_ttl;
                 has_found = true;
                 break;
@@ -172,7 +173,8 @@ class DiskTable : public Table {

     bool Put(const std::string& pk, uint64_t time, const char* data, uint32_t size) override;

-    bool Put(uint64_t time, const std::string& value, const Dimensions& dimensions) override;
+    absl::Status Put(uint64_t time, const std::string& value, const Dimensions& dimensions,
+                     bool put_if_absent) override;

     bool Get(uint32_t idx, const std::string& pk, uint64_t ts, std::string& value);  // NOLINT

@@ -183,8 +185,8 @@ class DiskTable : public Table {

     base::Status Truncate();

-    bool Delete(uint32_t idx, const std::string& pk,
-                const std::optional& start_ts, const std::optional& end_ts) override;
+    bool Delete(uint32_t idx, const std::string& pk, const std::optional<uint64_t>& start_ts,
+                const std::optional<uint64_t>& end_ts) override;

     uint64_t GetExpireTime(const TTLSt& ttl_st) override;
@@ -233,7 +235,7 @@ class DiskTable : public Table {
     uint64_t GetRecordByteSize() const override { return 0; }
     uint64_t GetRecordIdxByteSize() override;

-    int GetCount(uint32_t index, const std::string& pk, uint64_t& count) override; // NOLINT
+    int GetCount(uint32_t index, const std::string& pk, uint64_t& count) override;  // NOLINT

 private:
     base::Status Delete(uint32_t idx, const std::string& pk, uint64_t start_ts, const std::optional<uint64_t>& end_ts);
diff --git a/src/storage/disk_table_test.cc b/src/storage/disk_table_test.cc
index 2a4e0d53c98..04a5d6edbb3 100644
--- a/src/storage/disk_table_test.cc
+++ b/src/storage/disk_table_test.cc
@@ -111,7 +111,7 @@ TEST_F(DiskTableTest, MultiDimensionPut) {
     mapping.insert(std::make_pair("idx1", 1));
     mapping.insert(std::make_pair("idx2", 2));
     std::string table_path = FLAGS_hdd_root_path + "/2_1";
-    DiskTable* table = new DiskTable("yjtable2", 2, 1, mapping, 10, ::openmldb::type::TTLType::kAbsoluteTime,
+    Table* table = new DiskTable("yjtable2", 2, 1, mapping, 10, ::openmldb::type::TTLType::kAbsoluteTime,
                                      ::openmldb::common::StorageMode::kHDD, table_path);
     ASSERT_TRUE(table->Init());
     ASSERT_EQ(3, (int64_t)table->GetIdxCnt());
@@ -136,7 +136,7 @@ TEST_F(DiskTableTest, MultiDimensionPut) {
     d2->set_idx(2);
     std::string value;
     ASSERT_EQ(0, sdk_codec.EncodeRow(row, &value));
-    bool ok = table->Put(1, value, dimensions);
+    bool ok = table->Put(1, value, dimensions).ok();
     ASSERT_TRUE(ok);
     // some functions in disk table need to be implemented.
     // refer to issue #1238
@@ -202,7 +202,7 @@ TEST_F(DiskTableTest, MultiDimensionPut) {

     row = {"valuea", "valueb", "valuec"};
     ASSERT_EQ(0, sdk_codec.EncodeRow(row, &value));
-    ASSERT_TRUE(table->Put(2, value, dimensions));
+    ASSERT_TRUE(table->Put(2, value, dimensions).ok());

     it = table->NewIterator(0, "key2", ticket);
     it->SeekToFirst();
@@ -223,7 +223,7 @@ TEST_F(DiskTableTest, MultiDimensionPut) {
     delete it;

     std::string val;
-    ASSERT_TRUE(table->Get(1, "key1", 2, val));
+    ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, "key1", 2, val));
     data = reinterpret_cast<const int8_t*>(val.data());
     version = codec::RowView::GetSchemaVersion(data);
     decoder = table->GetVersionDecoder(version);
@@ -277,7 +277,7 @@ TEST_F(DiskTableTest, LongPut) {
     mapping.insert(std::make_pair("idx0", 0));
     mapping.insert(std::make_pair("idx1", 1));
     std::string table_path = FLAGS_ssd_root_path + "/3_1";
-    DiskTable* table = new DiskTable("yjtable3", 3, 1, mapping, 10, ::openmldb::type::TTLType::kAbsoluteTime,
+    Table* table = new DiskTable("yjtable3", 3, 1, mapping, 10, ::openmldb::type::TTLType::kAbsoluteTime,
                                      ::openmldb::common::StorageMode::kSSD, table_path);
     auto meta = ::openmldb::test::GetTableMeta({"idx0", "idx1"});
     ::openmldb::codec::SDKCodec sdk_codec(meta);
@@ -297,7 +297,7 @@ TEST_F(DiskTableTest, LongPut) {
         std::string value;
         ASSERT_EQ(0, sdk_codec.EncodeRow(row, &value));
         for (int k = 0; k < 10; k++) {
-            ASSERT_TRUE(table->Put(ts + k, value, dimensions));
+            ASSERT_TRUE(table->Put(ts + k, value, dimensions).ok());
         }
     }
     for (int idx = 0; idx < 10; idx++) {
@@ -465,7 +465,7 @@ TEST_F(DiskTableTest, TraverseIterator) {
     }
     ASSERT_EQ(20, count);
     std::string val;
-    ASSERT_TRUE(table->Get(0, "test98", 9548, val));
+    ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, "test98", 9548, val));
     ASSERT_EQ("valu8", val);
     delete it;
     delete table;
@@ -733,7 +733,7 @@ TEST_F(DiskTableTest, CompactFilter) {
     std::map<std::string, uint32_t> mapping;
     mapping.insert(std::make_pair("idx0", 0));
     std::string table_path = FLAGS_hdd_root_path + "/10_1";
-    DiskTable* table = new DiskTable("t1", 10, 1, mapping, 10, ::openmldb::type::TTLType::kAbsoluteTime,
+    Table* table = new DiskTable("t1", 10, 1, mapping, 10, ::openmldb::type::TTLType::kAbsoluteTime,
                                      ::openmldb::common::StorageMode::kHDD, table_path);
     ASSERT_TRUE(table->Init());
     uint64_t cur_time = ::baidu::common::timer::get_micros() / 1000;
@@ -754,24 +754,24 @@ TEST_F(DiskTableTest, CompactFilter) {
         for (int k = 0; k < 5; k++) {
             std::string value;
             if (k > 2) {
-                ASSERT_TRUE(table->Get(key, ts - k - 10 * 60 * 1000, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(key, ts - k - 10 * 60 * 1000, value));
                 ASSERT_EQ("value9", value);
             } else {
-                ASSERT_TRUE(table->Get(key, ts - k, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(key, ts - k, value));
                 ASSERT_EQ("value", value);
             }
         }
     }
-    table->CompactDB();
+    reinterpret_cast<DiskTable*>(table)->CompactDB();
     for (int idx = 0; idx < 100; idx++) {
         std::string key = "test" + std::to_string(idx);
         uint64_t ts = cur_time;
         for (int k = 0; k < 5; k++) {
             std::string value;
             if (k > 2) {
-                ASSERT_FALSE(table->Get(key, ts - k - 10 * 60 * 1000, value));
+                ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(key, ts - k - 10 * 60 * 1000, value));
             } else {
-                ASSERT_TRUE(table->Get(key, ts - k, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(key, ts - k, value));
                 ASSERT_EQ("value", value);
             }
         }
@@ -794,7 +794,7 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
     SchemaCodec::SetIndex(table_meta.add_column_key(), "mcc", "mcc", "ts2", ::openmldb::type::kAbsoluteTime, 5, 0);

     std::string table_path = FLAGS_hdd_root_path + "/11_1";
-    DiskTable* table = new DiskTable(table_meta, table_path);
+    Table* table = new DiskTable(table_meta, table_path);
     ASSERT_TRUE(table->Init());

     codec::SDKCodec codec(table_meta);
@@ -818,7 +818,7 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
                                                 std::to_string(cur_time - i * 60 * 1000)};
                 std::string value;
                 ASSERT_EQ(0, codec.EncodeRow(row, &value));
-                ASSERT_TRUE(table->Put(cur_time - i * 60 * 1000, value, dims));
+                ASSERT_TRUE(table->Put(cur_time - i * 60 * 1000, value, dims).ok());
             }
         } else {
@@ -828,7 +828,7 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
                                                 std::to_string(cur_time - i)};
                 std::string value;
                 ASSERT_EQ(0, codec.EncodeRow(row, &value));
-                ASSERT_TRUE(table->Put(cur_time - i, value, dims));
+                ASSERT_TRUE(table->Put(cur_time - i, value, dims).ok());
             }
         }
     }
@@ -860,11 +860,11 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
                 std::string e_value;
                 ASSERT_EQ(0, codec.EncodeRow(row, &e_value));
                 std::string value;
-                ASSERT_TRUE(table->Get(0, key, ts - i * 60 * 1000, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, key, ts - i * 60 * 1000, value));
                 ASSERT_EQ(e_value, value);
-                ASSERT_TRUE(table->Get(1, key, ts - i * 60 * 1000, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, ts - i * 60 * 1000, value));
                 ASSERT_EQ(e_value, value);
-                ASSERT_TRUE(table->Get(2, key1, ts - i * 60 * 1000, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, ts - i * 60 * 1000, value));
             }
         } else {
@@ -874,15 +874,15 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
                 std::string e_value;
                 ASSERT_EQ(0, codec.EncodeRow(row, &e_value));
                 std::string value;
-                ASSERT_TRUE(table->Get(0, key, ts - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, key, ts - i, value));
                 ASSERT_EQ(e_value, value);
-                ASSERT_TRUE(table->Get(1, key, ts - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, ts - i, value));
                 ASSERT_EQ(e_value, value);
-                ASSERT_TRUE(table->Get(2, key1, ts - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, ts - i, value));
             }
         }
     }
-    table->CompactDB();
+    reinterpret_cast<DiskTable*>(table)->CompactDB();
     iter = table->NewIterator(0, "card0", ticket);
     iter->SeekToFirst();
     while (iter->Valid()) {
@@ -908,18 +908,18 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
                 ASSERT_EQ(0, codec.EncodeRow(row, &e_value));
                 std::string value;
                 if (i < 3) {
-                    ASSERT_TRUE(table->Get(0, key, cur_ts, value));
+                    ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_ts, value));
                     ASSERT_EQ(e_value, value);
                 } else {
-                    ASSERT_FALSE(table->Get(0, key, cur_ts, value));
+                    ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_ts, value));
                 }
                 if (i < 5) {
-                    ASSERT_TRUE(table->Get(1, key, cur_ts, value));
+                    ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_ts, value));
                     ASSERT_EQ(e_value, value);
-                    ASSERT_TRUE(table->Get(2, key1, cur_ts, value));
+                    ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_ts, value));
                 } else {
-                    ASSERT_FALSE(table->Get(1, key, cur_ts, value));
-                    ASSERT_FALSE(table->Get(2, key1, cur_ts, value));
+                    ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_ts, value));
+                    ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_ts, value));
                 }
             }
         } else {
@@ -929,11 +929,11 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
                 std::string e_value;
                 ASSERT_EQ(0, codec.EncodeRow(row, &e_value));
                 std::string value;
-                ASSERT_TRUE(table->Get(0, key, ts - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, key, ts - i, value));
                 ASSERT_EQ(e_value, value);
-                ASSERT_TRUE(table->Get(1, key, ts - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, ts - i, value));
                 ASSERT_EQ(e_value, value);
-                ASSERT_TRUE(table->Get(2, key1, ts - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, ts - i, value));
             }
         }
     }
@@ -955,7 +955,8 @@ TEST_F(DiskTableTest, GcHeadMulTs) {
     SchemaCodec::SetIndex(table_meta.add_column_key(), "mcc", "mcc", "ts2", ::openmldb::type::kLatestTime, 0, 5);

     std::string table_path = FLAGS_hdd_root_path + "/12_1";
-    DiskTable* table = new DiskTable(table_meta, table_path);
+    // the Table base class doesn't have a Get method, so cast to DiskTable to call Get
+    Table* table = new DiskTable(table_meta, table_path);
     ASSERT_TRUE(table->Init());

     codec::SDKCodec codec(table_meta);
@@ -980,7 +981,7 @@ TEST_F(DiskTableTest, GcHeadMulTs) {
                                             std::to_string(cur_time - i), std::to_string(cur_time - i)};
             std::string value;
             ASSERT_EQ(0, codec.EncodeRow(row, &value));
-            ASSERT_TRUE(table->Put(cur_time - i, value, dims));
+            ASSERT_TRUE(table->Put(cur_time - i, value, dims).ok());
         }
     }
     Ticket ticket;
@@ -1006,15 +1007,15 @@ TEST_F(DiskTableTest, GcHeadMulTs) {
             ASSERT_EQ(0, codec.EncodeRow(row, &e_value));
             std::string value;
             if (idx == 50 && i > 2) {
-                ASSERT_FALSE(table->Get(0, key, cur_time - i, value));
-                ASSERT_FALSE(table->Get(1, key, cur_time - i, value));
-                ASSERT_FALSE(table->Get(2, key1, cur_time - i, value));
+                ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_time - i, value));
+                ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_time - i, value));
+                ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_time - i, value));
             } else {
-                ASSERT_TRUE(table->Get(0, key, cur_time - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_time - i, value));
                 ASSERT_EQ(e_value, value);
-                ASSERT_TRUE(table->Get(1, key, cur_time - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_time - i, value));
                 ASSERT_EQ(e_value, value);
-                ASSERT_TRUE(table->Get(2, key1, cur_time - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_time - i, value));
             }
         }
     }
@@ -1041,24 +1042,24 @@ TEST_F(DiskTableTest, GcHeadMulTs) {
             ASSERT_EQ(0, codec.EncodeRow(row, &e_value));
             std::string value;
             if (idx == 50 && i > 2) {
-                ASSERT_FALSE(table->Get(0, key, cur_time - i, value));
-                ASSERT_FALSE(table->Get(1, key, cur_time - i, value));
-                ASSERT_FALSE(table->Get(2, key1, cur_time - i, value));
+                ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_time - i, value));
+                ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_time - i, value));
+                ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_time - i, value));
             } else if (i < 3) {
-                ASSERT_TRUE(table->Get(0, key, cur_time - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_time - i, value));
                 ASSERT_EQ(e_value, value);
-                ASSERT_TRUE(table->Get(1, key, cur_time - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_time - i, value));
                 ASSERT_EQ(e_value, value);
-                ASSERT_TRUE(table->Get(2, key1, cur_time - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_time - i, value));
             } else if (i < 5) {
-                ASSERT_FALSE(table->Get(0, key, cur_time - i, value));
-                ASSERT_TRUE(table->Get(1, key, cur_time - i, value));
+                ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_time - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_time - i, value));
                 ASSERT_EQ(e_value, value);
-                ASSERT_TRUE(table->Get(2, key1, cur_time - i, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_time - i, value));
             } else {
-                ASSERT_FALSE(table->Get(0, key, cur_time - i, value));
-                ASSERT_FALSE(table->Get(1, key, cur_time - i, value));
-                ASSERT_FALSE(table->Get(2, key1, cur_time - i, value));
+                ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_time - i, value));
+                ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_time - i, value));
+                ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_time - i, value));
             }
         }
     }
@@ -1089,7 +1090,7 @@ TEST_F(DiskTableTest, GcHead) {
         uint64_t ts = 9537;
         for (int k = 0; k < 5; k++) {
             std::string value;
-            ASSERT_TRUE(table->Get(key, ts + k, value));
+            ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(key, ts + k, value));
             if (idx == 10 && k == 2) {
                 ASSERT_EQ("value8", value);
             } else {
@@ -1104,9 +1105,9 @@ TEST_F(DiskTableTest, GcHead) {
         for (int k = 0; k < 5; k++) {
             std::string value;
             if (k < 2) {
-                ASSERT_FALSE(table->Get(key, ts + k, value));
+                ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(key, ts + k, value));
             } else {
-                ASSERT_TRUE(table->Get(key, ts + k, value));
+                ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(key, ts + k, value));
                 if (idx == 10 && k == 2) {
                     ASSERT_EQ("value8", value);
                 } else {
diff --git a/src/storage/key_entry.h b/src/storage/key_entry.h
index e8969c9b832..1b5f4778f4f 100644
--- a/src/storage/key_entry.h
+++ b/src/storage/key_entry.h
@@ -49,11 +49,19 @@ struct DataBlock {
         delete[] data;
         data = nullptr;
     }
+
+    bool EqualWithoutCnt(const DataBlock& other) const {
+        if (size != other.size) {
+            return false;
+        }
+        // could be improved, cf. the header/version layout in RowBuilder::InitBuffer
+        return memcmp(data, other.data, size) == 0;
+    }
 };

 // the desc time comparator
 struct TimeComparator {
-    int operator() (uint64_t a, uint64_t b) const {
+    int operator()(uint64_t a, uint64_t b) const {
         if (a > b) {
             return -1;
         } else if (a == b) {
@@ -86,7 +94,6 @@ class KeyEntry {
     std::atomic count_;
 };

-
 }  // namespace storage
 }  // namespace openmldb
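`EqualWithoutCnt` is what makes retried puts idempotent at the value level: only the payload bytes are compared, never the reference count a block was created with. A quick check (the `DataBlock(cnt, data, size)` constructor shape is taken from its uses elsewhere in this diff; the standalone `main` is just for illustration):

    #include <cassert>
    #include "storage/key_entry.h"

    using openmldb::storage::DataBlock;

    int main() {
        DataBlock a(1, "test1", 5);
        DataBlock b(3, "test1", 5);  // different ref count, same 5-byte payload
        assert(a.EqualWithoutCnt(b));
        DataBlock c(1, "test2", 5);
        assert(!a.EqualWithoutCnt(c));  // payload differs
        return 0;
    }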
diff --git a/src/storage/mem_table.cc b/src/storage/mem_table.cc
index db7578619b7..3a57ffc4e93 100644
--- a/src/storage/mem_table.cc
+++ b/src/storage/mem_table.cc
@@ -140,21 +140,21 @@ bool MemTable::Put(const std::string& pk, uint64_t time, const char* data, uint3
     return true;
 }

-bool MemTable::Put(uint64_t time, const std::string& value, const Dimensions& dimensions) {
+absl::Status MemTable::Put(uint64_t time, const std::string& value, const Dimensions& dimensions, bool put_if_absent) {
     if (dimensions.empty()) {
         PDLOG(WARNING, "empty dimension. tid %u pid %u", id_, pid_);
-        return false;
+        return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": empty dimension"));
     }
     if (value.length() < codec::HEADER_LENGTH) {
         PDLOG(WARNING, "invalid value. tid %u pid %u", id_, pid_);
-        return false;
+        return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid value"));
     }
+    // inner index pos: -1 means invalid, so only positive values end up in inner_index_key_map
     std::map<int32_t, std::string> inner_index_key_map;
     for (auto iter = dimensions.begin(); iter != dimensions.end(); iter++) {
         int32_t inner_pos = table_index_.GetInnerIndexPos(iter->idx());
         if (inner_pos < 0) {
-            PDLOG(WARNING, "invalid dimension. dimension idx %u, tid %u pid %u", iter->idx(), id_, pid_);
-            return false;
+            return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid dimension idx ", iter->idx()));
         }
         inner_index_key_map.emplace(inner_pos, iter->key());
     }
@@ -168,15 +168,13 @@ bool MemTable::Put(uint64_t time, const std::string& value, const Dimensions& di
     uint8_t version = codec::RowView::GetSchemaVersion(data);
     auto decoder = GetVersionDecoder(version);
     if (decoder == nullptr) {
-        PDLOG(WARNING, "invalid schema version %u, tid %u pid %u", version, id_, pid_);
-        return false;
+        return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid schema version ", version));
     }
     std::map<uint32_t, std::map<int32_t, uint64_t>> ts_value_map;
     for (const auto& kv : inner_index_key_map) {
         auto inner_index = table_index_.GetInnerIndex(kv.first);
         if (!inner_index) {
-            PDLOG(WARNING, "invalid inner index pos %d. tid %u pid %u", kv.first, id_, pid_);
-            return false;
+            return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid inner index pos ", kv.first));
         }
         std::map<int32_t, uint64_t> ts_map;
         for (const auto& index_def : inner_index->GetIndex()) {
@@ -189,13 +187,12 @@
             if (ts_col->IsAutoGenTs()) {
                 ts = time;
             } else if (decoder->GetInteger(data, ts_col->GetId(), ts_col->GetType(), &ts) != 0) {
-                PDLOG(WARNING, "get ts failed. tid %u pid %u", id_, pid_);
-                return false;
+                return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": get ts failed"));
             }
             if (ts < 0) {
-                PDLOG(WARNING, "ts %ld is negative. tid %u pid %u", ts, id_, pid_);
-                return false;
+                return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": ts is negative ", ts));
             }
+            // TODO(hw): why uint32_t to int32_t?
             ts_map.emplace(ts_col->GetId(), ts);
             real_ref_cnt++;
         }
@@ -205,7 +202,7 @@
         }
     }
     if (ts_value_map.empty()) {
-        return false;
+        return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": empty ts value map"));
     }
     auto* block = new DataBlock(real_ref_cnt, value.c_str(), value.length());
     for (const auto& kv : inner_index_key_map) {
@@ -218,10 +215,12 @@
             seg_idx = ::openmldb::base::hash(kv.second.data(), kv.second.size(), SEED) % seg_cnt_;
         }
         Segment* segment = segments_[kv.first][seg_idx];
-        segment->Put(::openmldb::base::Slice(kv.second), iter->second, block);
+        if (!segment->Put(kv.second, iter->second, block, put_if_absent)) {
+            return absl::AlreadyExistsError("data exists");  // let the caller know the row already exists
+        }
     }
     record_byte_size_.fetch_add(GetRecordSize(value.length()));
-    return true;
+    return absl::OkStatus();
 }

 bool MemTable::Delete(const ::openmldb::api::LogEntry& entry) {
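With `absl::Status` in the signature, a duplicate row under `put_if_absent` is now distinguishable from a genuine failure. A caller-side sketch (the `table`, `ts`, `encoded_row`, and `dimensions` variables are placeholders):

    #include "absl/status/status.h"

    // Retrying the same row with put_if_absent=true is safe: the duplicate
    // surfaces as kAlreadyExists instead of an opaque false.
    absl::Status st = table->Put(ts, encoded_row, dimensions, /*put_if_absent=*/true);
    if (absl::IsAlreadyExists(st)) {
        // benign: the row landed in an earlier attempt
    } else if (!st.ok()) {
        LOG(WARNING) << "put failed: " << st.ToString();
    }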
@@ -550,6 +549,7 @@ uint64_t MemTable::GetRecordIdxCnt() {
     if (!index_def || !index_def->IsReady()) {
         return record_idx_cnt;
     }
+
     uint32_t inner_idx = index_def->GetInnerPos();
     auto inner_index = table_index_.GetInnerIndex(inner_idx);
     int32_t ts_col_id = -1;
@@ -665,7 +665,7 @@ bool MemTable::AddIndex(const ::openmldb::common::ColumnKey& column_key) {
         }
         ts_vec.push_back(ts_iter->second.GetId());
     } else {
-        ts_vec.push_back(DEFUALT_TS_COL_ID);
+        ts_vec.push_back(DEFAULT_TS_COL_ID);
     }
     uint32_t inner_id = table_index_.GetAllInnerIndex()->size();
     Segment** seg_arr = new Segment*[seg_cnt_];
@@ -685,7 +685,7 @@ bool MemTable::AddIndex(const ::openmldb::common::ColumnKey& column_key) {
         auto ts_iter = schema.find(column_key.ts_name());
         index_def->SetTsColumn(std::make_shared<ColumnDef>(ts_iter->second));
     } else {
-        index_def->SetTsColumn(std::make_shared<ColumnDef>(DEFUALT_TS_COL_NAME, DEFUALT_TS_COL_ID,
+        index_def->SetTsColumn(std::make_shared<ColumnDef>(DEFAULT_TS_COL_NAME, DEFAULT_TS_COL_ID,
                                                            ::openmldb::type::kTimestamp, true));
     }
     if (column_key.has_ttl()) {
@@ -724,14 +724,14 @@ bool MemTable::DeleteIndex(const std::string& idx_name) {
         new_table_meta->mutable_column_key(index_def->GetId())->set_flag(1);
     }
     std::atomic_store_explicit(&table_meta_, new_table_meta, std::memory_order_release);
-    index_def->SetStatus(IndexStatus::kWaiting);
+    index_def->SetStatus(IndexStatus::kWaiting);  // let gc do the deletion
     return true;
 }

 ::hybridse::vm::WindowIterator* MemTable::NewWindowIterator(uint32_t index) {
     std::shared_ptr<IndexDef> index_def = table_index_.GetIndex(index);
     if (!index_def || !index_def->IsReady()) {
-        LOG(WARNING) << "index id " << index << " not found. tid " << id_
-                     << " pid " << pid_;
+        LOG(WARNING) << "index id " << index << " not found. tid " << id_ << " pid " << pid_;
         return nullptr;
     }
     uint64_t expire_time = 0;
diff --git a/src/storage/mem_table.h b/src/storage/mem_table.h
index 8ae1964e0ef..e85762a97dc 100644
--- a/src/storage/mem_table.h
+++ b/src/storage/mem_table.h
@@ -51,7 +51,8 @@ class MemTable : public Table {

     bool Put(const std::string& pk, uint64_t time, const char* data, uint32_t size) override;

-    bool Put(uint64_t time, const std::string& value, const Dimensions& dimensions) override;
+    absl::Status Put(uint64_t time, const std::string& value, const Dimensions& dimensions,
+                     bool put_if_absent) override;

     bool GetBulkLoadInfo(::openmldb::api::BulkLoadInfoResponse* response);

@@ -59,8 +60,8 @@ class MemTable : public Table {
                       const ::google::protobuf::RepeatedPtrField<::openmldb::api::BulkLoadIndex>& indexes);

     bool Delete(const ::openmldb::api::LogEntry& entry) override;
-    bool Delete(uint32_t idx, const std::string& key,
-                const std::optional& start_ts, const std::optional& end_ts);
+    bool Delete(uint32_t idx, const std::string& key, const std::optional<uint64_t>& start_ts,
+                const std::optional<uint64_t>& end_ts);

     // use the first dimension
     TableIterator* NewIterator(const std::string& pk, Ticket& ticket) override;
diff --git a/src/storage/schema.cc b/src/storage/schema.cc
index 7efff1d35a4..3250a047a8b 100644
--- a/src/storage/schema.cc
+++ b/src/storage/schema.cc
@@ -216,7 +216,7 @@ int TableIndex::ParseFromMeta(const ::openmldb::api::TableMeta& table_meta) {
             index->SetTsColumn(col_map[ts_name]);
         } else {
             // set default ts col
-            index->SetTsColumn(std::make_shared<ColumnDef>(DEFUALT_TS_COL_NAME, DEFUALT_TS_COL_ID,
+            index->SetTsColumn(std::make_shared<ColumnDef>(DEFAULT_TS_COL_NAME, DEFAULT_TS_COL_ID,
                                                            ::openmldb::type::kTimestamp, true));
         }
         if (column_key.has_ttl()) {
@@ -232,7 +232,7 @@ int TableIndex::ParseFromMeta(const ::openmldb::api::TableMeta& table_meta) {
     // add default dimension
     if (indexs_->empty()) {
         auto index = std::make_shared<IndexDef>("idx0", 0);
-        index->SetTsColumn(std::make_shared<ColumnDef>(DEFUALT_TS_COL_NAME, DEFUALT_TS_COL_ID,
+        index->SetTsColumn(std::make_shared<ColumnDef>(DEFAULT_TS_COL_NAME, DEFAULT_TS_COL_ID,
                                                        ::openmldb::type::kTimestamp, true));
         if (AddIndex(index) < 0) {
             DLOG(WARNING) << "add index failed";
diff --git a/src/storage/schema.h b/src/storage/schema.h
index 2143744122e..39359761ed9 100644
--- a/src/storage/schema.h
+++ b/src/storage/schema.h
@@ -31,8 +31,8 @@ namespace openmldb::storage {

 static constexpr uint32_t MAX_INDEX_NUM = 200;
-static constexpr uint32_t DEFUALT_TS_COL_ID = UINT32_MAX;
-static constexpr const char* DEFUALT_TS_COL_NAME = "default_ts";
+static constexpr uint32_t DEFAULT_TS_COL_ID = UINT32_MAX;
+static constexpr const char* DEFAULT_TS_COL_NAME = "default_ts";

 enum TTLType { kAbsoluteTime = 1, kRelativeTime = 2, kLatestTime = 3, kAbsAndLat = 4, kAbsOrLat = 5 };

@@ -163,7 +163,7 @@ class ColumnDef {
         return false;
     }

-    inline bool IsAutoGenTs() const { return id_ == DEFUALT_TS_COL_ID; }
+    inline bool IsAutoGenTs() const { return id_ == DEFAULT_TS_COL_ID; }

 private:
     std::string name_;
diff --git a/src/storage/schema_test.cc b/src/storage/schema_test.cc
index 1c169697634..77840c13e93 100644
--- a/src/storage/schema_test.cc
+++ b/src/storage/schema_test.cc
@@ -233,9 +233,9 @@ TEST_F(SchemaTest, TsAndDefaultTs) {
                 ::openmldb::storage::kAbsoluteTime);
     AssertIndex(*(table_index.GetIndex("key2")), "key2", "col1", "col7", 7, 10, 0, ::openmldb::storage::kAbsoluteTime);
     AssertIndex(*(table_index.GetIndex("key3")), "key3", "col2", "col6", 6, 10, 0, ::openmldb::storage::kAbsoluteTime);
-    AssertIndex(*(table_index.GetIndex("key4")), "key4", "col2", DEFUALT_TS_COL_NAME, DEFUALT_TS_COL_ID,
+    AssertIndex(*(table_index.GetIndex("key4")), "key4", "col2", DEFAULT_TS_COL_NAME, DEFAULT_TS_COL_ID,
                 10, 0, ::openmldb::storage::kAbsoluteTime);
-    AssertIndex(*(table_index.GetIndex("key5")), "key5", "col3", DEFUALT_TS_COL_NAME, DEFUALT_TS_COL_ID,
+    AssertIndex(*(table_index.GetIndex("key5")), "key5", "col3", DEFAULT_TS_COL_NAME, DEFAULT_TS_COL_ID,
                 10, 0, ::openmldb::storage::kAbsoluteTime);
     auto inner_index = table_index.GetAllInnerIndex();
     ASSERT_EQ(inner_index->size(), 3u);
@@ -243,10 +243,10 @@ TEST_F(SchemaTest, TsAndDefaultTs) {
     std::vector<std::string> index0 = {"key1", "key2"};
     std::vector<uint32_t> ts_vec0 = {6, 7};
     AssertInnerIndex(*(table_index.GetInnerIndex(0)), 0, index0, ts_vec0);
     std::vector<std::string> index1 = {"key3", "key4"};
-    std::vector<uint32_t> ts_vec1 = {6, DEFUALT_TS_COL_ID};
+    std::vector<uint32_t> ts_vec1 = {6, DEFAULT_TS_COL_ID};
     AssertInnerIndex(*(table_index.GetInnerIndex(1)), 1, index1, ts_vec1);
     std::vector<std::string> index2 = {"key5"};
-    std::vector<uint32_t> ts_vec2 = {DEFUALT_TS_COL_ID};
+    std::vector<uint32_t> ts_vec2 = {DEFAULT_TS_COL_ID};
     AssertInnerIndex(*(table_index.GetInnerIndex(2)), 2, index2, ts_vec2);
 }
diff --git a/src/storage/segment.cc b/src/storage/segment.cc
index d79b6e85681..8255d27b7bd 100644
--- a/src/storage/segment.cc
+++ b/src/storage/segment.cc
@@ -15,7 +15,9 @@
  */

 #include "storage/segment.h"
+#include
+#include

 #include "base/glog_wrapper.h"
@@ -64,9 +66,7 @@ Segment::Segment(uint8_t height, const std::vector<uint32_t>& ts_idx_vec)
     }
 }

-Segment::~Segment() {
-    delete entries_;
-}
+Segment::~Segment() { delete entries_; }

 void Segment::Release(StatisticsInfo* statistics_info) {
     std::unique_ptr<KeyEntries::Iterator> it(entries_->NewIterator());
@@ -98,9 +98,7 @@ void Segment::Release(StatisticsInfo* statistics_info) {
     }
 }

-void Segment::ReleaseAndCount(StatisticsInfo* statistics_info) {
-    Release(statistics_info);
-}
+void Segment::ReleaseAndCount(StatisticsInfo* statistics_info) { Release(statistics_info); }

 void Segment::ReleaseAndCount(const std::vector<uint32_t>& id_vec, StatisticsInfo* statistics_info) {
     if (ts_cnt_ <= 1) {
@@ -135,25 +133,28 @@ void Segment::ReleaseAndCount(const std::vector<uint32_t>& id_vec, StatisticsInfo*
     }
 }

-void Segment::Put(const Slice& key, uint64_t time, const char* data, uint32_t size) {
+void Segment::Put(const Slice& key, uint64_t time, const char* data, uint32_t size, bool put_if_absent,
+                  bool check_all_time) {
     if (ts_cnt_ > 1) {
         return;
     }
     auto* db = new DataBlock(1, data, size);
-    Put(key, time, db);
+    Put(key, time, db, put_if_absent, check_all_time);
 }

-void Segment::Put(const Slice& key, uint64_t time, DataBlock* row) {
+bool Segment::Put(const Slice& key, uint64_t time, DataBlock* row, bool put_if_absent, bool check_all_time) {
     if (ts_cnt_ > 1) {
-        return;
+        LOG(ERROR) << "wrong call";
+        return false;
     }
     std::lock_guard<std::mutex> lock(mu_);
-    PutUnlock(key, time, row);
+    return PutUnlock(key, time, row, put_if_absent, check_all_time);
 }

-void Segment::PutUnlock(const Slice& key, uint64_t time, DataBlock* row) {
+bool Segment::PutUnlock(const Slice& key, uint64_t time, DataBlock* row, bool put_if_absent, bool check_all_time) {
     void* entry = nullptr;
     uint32_t byte_size = 0;
+    // one key maps to exactly one entry
     int ret = entries_->Get(key, entry);
     if (ret < 0 || entry == nullptr) {
         char* pk = new char[key.size()];
@@ -164,12 +165,17 @@ bool Segment::PutUnlock(const Slice& key, uint64_t time, DataBlock* row, bool pu
         uint8_t height = entries_->Insert(skey, entry);
         byte_size += GetRecordPkIdxSize(height, key.size(), key_entry_max_height_);
         pk_cnt_.fetch_add(1, std::memory_order_relaxed);
+        // no need to check for absence on the first put of a key
+    } else if (put_if_absent && ListContains(reinterpret_cast<KeyEntry*>(entry), time, row, check_all_time)) {
+        return false;
     }
+
     idx_cnt_vec_[0]->fetch_add(1, std::memory_order_relaxed);
     uint8_t height = reinterpret_cast<KeyEntry*>(entry)->entries.Insert(time, row);
     reinterpret_cast<KeyEntry*>(entry)->count_.fetch_add(1, std::memory_order_relaxed);
     byte_size += GetRecordTsIdxSize(height);
     idx_byte_size_.fetch_add(byte_size, std::memory_order_relaxed);
+    return true;
 }

 void Segment::BulkLoadPut(unsigned int key_entry_id, const Slice& key, uint64_t time, DataBlock* row) {
@@ -201,16 +207,17 @@ void Segment::BulkLoadPut(unsigned int key_entry_id, const Slice& key, uint64_t
     }
 }

-void Segment::Put(const Slice& key, const std::map<int32_t, uint64_t>& ts_map, DataBlock* row) {
-    uint32_t ts_size = ts_map.size();
-    if (ts_size == 0) {
-        return;
+bool Segment::Put(const Slice& key, const std::map<int32_t, uint64_t>& ts_map, DataBlock* row, bool put_if_absent) {
+    if (ts_map.empty()) {
+        return false;
     }
     if (ts_cnt_ == 1) {
+        bool ret = false;
         if (auto pos = ts_map.find(ts_idx_map_.begin()->first); pos != ts_map.end()) {
-            Put(key, pos->second, row);
+            // TODO(hw): why is the ts_map key int32_t while the default ts col id is uint32_t?
+            ret = Put(key, pos->second, row, put_if_absent, pos->first == DEFAULT_TS_COL_ID);
         }
-        return;
+        return ret;
     }
     void* entry_arr = nullptr;
     std::lock_guard<std::mutex> lock(mu_);
@@ -237,12 +244,16 @@
         }
     }
     auto entry = reinterpret_cast<KeyEntry**>(entry_arr)[pos->second];
+    if (put_if_absent && ListContains(entry, kv.second, row, pos->first == DEFAULT_TS_COL_ID)) {
+        return false;
+    }
     uint8_t height = entry->entries.Insert(kv.second, row);
     entry->count_.fetch_add(1, std::memory_order_relaxed);
     byte_size += GetRecordTsIdxSize(height);
     idx_byte_size_.fetch_add(byte_size, std::memory_order_relaxed);
     idx_cnt_vec_[pos->second]->fetch_add(1, std::memory_order_relaxed);
     }
+    return true;
 }

 bool Segment::Delete(const std::optional<uint32_t>& idx, const Slice& key) {
@@ -289,8 +300,8 @@ bool Segment::Delete(const std::optional<uint32_t>& idx, const Slice& key) {
     return true;
 }

-bool Segment::Delete(const std::optional<uint32_t>& idx, const Slice& key,
-                     uint64_t ts, const std::optional<uint64_t>& end_ts) {
+bool Segment::Delete(const std::optional<uint32_t>& idx, const Slice& key, uint64_t ts,
+                     const std::optional<uint64_t>& end_ts) {
     void* entry = nullptr;
     if (entries_->Get(key, entry) < 0 || entry == nullptr) {
         return true;
@@ -347,7 +358,7 @@
 }

 void Segment::FreeList(uint32_t ts_idx, ::openmldb::base::Node<uint64_t, DataBlock*>* node,
-                        StatisticsInfo* statistics_info) {
+                       StatisticsInfo* statistics_info) {
     while (node != nullptr) {
         statistics_info->IncrIdxCnt(ts_idx);
         ::openmldb::base::Node<uint64_t, DataBlock*>* tmp = node;
@@ -365,7 +376,6 @@ void Segment::FreeList(uint32_t ts_idx, ::openmldb::base::Node
+bool Segment::ListContains(KeyEntry* entry, uint64_t time, DataBlock* row, bool check_all_time) {
+    std::unique_ptr<TimeEntries::Iterator> it(entry->entries.NewIterator());
+    if (check_all_time) {
+        it->SeekToFirst();
+        while (it->Valid()) {
+            if (it->GetValue()->EqualWithoutCnt(*row)) {
+                return true;
+            }
+            it->Next();
+        }
+    } else {
+        // the time comparator is descending, so Seek(time) lands on entries <= time
+        // (not valid if the list is empty or all entries are > time); Next() moves to smaller keys
+        it->Seek(time);
+        while (it->Valid()) {
+            // key > time is just a protection, normally it should not happen
+            if (it->GetKey() < time || it->GetKey() > time) {
+                break;  // no entry == time, or all entries == time have been checked
+            }
+            if (it->GetValue()->EqualWithoutCnt(*row)) {
+                return true;
+            }
+            it->Next();
+        }
+    }
+    return false;
+}
+
 // fast gc with no global pause
 void Segment::Gc4TTL(const uint64_t time, StatisticsInfo* statistics_info) {
     uint64_t consumed = ::baidu::common::timer::get_micros();
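The dedup scan leans on the descending `TimeComparator`: `Seek(t)` positions at the newest entry whose key is <= t, and `Next()` walks toward older entries, so equal-time duplicates sit adjacent. A tiny sanity check of that ordering (uses the comparator struct from key_entry.h in this diff; the standalone `main` is illustrative):

    #include <cassert>
    #include "storage/key_entry.h"

    int main() {
        openmldb::storage::TimeComparator cmp;
        assert(cmp(5, 3) < 0);   // larger timestamps sort first (descending)
        assert(cmp(3, 3) == 0);
        assert(cmp(2, 3) > 0);
        // Hence Seek(t) stops at the first key <= t, which is exactly the
        // window ListContains scans when check_all_time is false.
        return 0;
    }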
uint64_t consumed = ::baidu::common::timer::get_micros(); @@ -606,8 +645,7 @@ void Segment::Gc4TTL(const uint64_t time, StatisticsInfo* statistics_info) { if (node == nullptr) { continue; } else if (node->GetKey() > time) { - DEBUGLOG("[Gc4TTL] segment gc with key %lu need not ttl, last node key %lu", - time, node->GetKey()); + DEBUGLOG("[Gc4TTL] segment gc with key %lu need not ttl, last node key %lu", time, node->GetKey()); continue; } node = nullptr; @@ -648,8 +686,7 @@ void Segment::Gc4TTLAndHead(const uint64_t time, const uint64_t keep_cnt, Statis if (node == nullptr) { continue; } else if (node->GetKey() > time) { - DEBUGLOG("[Gc4TTLAndHead] segment gc with key %lu need not ttl, last node key %lu", - time, node->GetKey()); + DEBUGLOG("[Gc4TTLAndHead] segment gc with key %lu need not ttl, last node key %lu", time, node->GetKey()); continue; } node = nullptr; @@ -663,8 +700,8 @@ void Segment::Gc4TTLAndHead(const uint64_t time, const uint64_t keep_cnt, Statis FreeList(0, node, statistics_info); entry->count_.fetch_sub(statistics_info->GetIdxCnt(0) - cur_idx_cnt, std::memory_order_relaxed); } - DEBUGLOG("[Gc4TTLAndHead] segment gc time %lu and keep cnt %lu consumed %lu, count %lu", - time, keep_cnt, (::baidu::common::timer::get_micros() - consumed) / 1000, statistics_info->GetIdxCnt(0) - old); + DEBUGLOG("[Gc4TTLAndHead] segment gc time %lu and keep cnt %lu consumed %lu, count %lu", time, keep_cnt, + (::baidu::common::timer::get_micros() - consumed) / 1000, statistics_info->GetIdxCnt(0) - old); idx_cnt_vec_[0]->fetch_sub(statistics_info->GetIdxCnt(0) - old, std::memory_order_relaxed); } @@ -709,8 +746,8 @@ void Segment::Gc4TTLOrHead(const uint64_t time, const uint64_t keep_cnt, Statist FreeList(0, node, statistics_info); entry->count_.fetch_sub(statistics_info->GetIdxCnt(0) - cur_idx_cnt, std::memory_order_relaxed); } - DEBUGLOG("[Gc4TTLAndHead] segment gc time %lu and keep cnt %lu consumed %lu, count %lu", - time, keep_cnt, (::baidu::common::timer::get_micros() - consumed) / 1000, statistics_info->GetIdxCnt(0) - old); + DEBUGLOG("[Gc4TTLAndHead] segment gc time %lu and keep cnt %lu consumed %lu, count %lu", time, keep_cnt, + (::baidu::common::timer::get_micros() - consumed) / 1000, statistics_info->GetIdxCnt(0) - old); idx_cnt_vec_[0]->fetch_sub(statistics_info->GetIdxCnt(0) - old, std::memory_order_relaxed); } @@ -754,8 +791,8 @@ MemTableIterator* Segment::NewIterator(const Slice& key, Ticket& ticket, type::C return new MemTableIterator(reinterpret_cast(entry)->entries.NewIterator(), compress_type); } -MemTableIterator* Segment::NewIterator(const Slice& key, uint32_t idx, - Ticket& ticket, type::CompressType compress_type) { +MemTableIterator* Segment::NewIterator(const Slice& key, uint32_t idx, Ticket& ticket, + type::CompressType compress_type) { auto pos = ts_idx_map_.find(idx); if (pos == ts_idx_map_.end()) { return new MemTableIterator(nullptr, compress_type); diff --git a/src/storage/segment.h b/src/storage/segment.h index fe58dd893a0..11322483832 100644 --- a/src/storage/segment.h +++ b/src/storage/segment.h @@ -70,20 +70,19 @@ class Segment { Segment(uint8_t height, const std::vector& ts_idx_vec); ~Segment(); - // Put time data - void Put(const Slice& key, uint64_t time, const char* data, uint32_t size); + // legacy interface called by memtable and ut + void Put(const Slice& key, uint64_t time, const char* data, uint32_t size, bool put_if_absent = false, + bool check_all_time = false); - void Put(const Slice& key, uint64_t time, DataBlock* row); - - void PutUnlock(const Slice& 
key, uint64_t time, DataBlock* row); + bool Put(const Slice& key, uint64_t time, DataBlock* row, bool put_if_absent = false, bool check_all_time = false); void BulkLoadPut(unsigned int key_entry_id, const Slice& key, uint64_t time, DataBlock* row); - - void Put(const Slice& key, const std::map& ts_map, DataBlock* row); + // main put method + bool Put(const Slice& key, const std::map& ts_map, DataBlock* row, bool put_if_absent = false); bool Delete(const std::optional& idx, const Slice& key); - bool Delete(const std::optional& idx, const Slice& key, - uint64_t ts, const std::optional& end_ts); + bool Delete(const std::optional& idx, const Slice& key, uint64_t ts, + const std::optional& end_ts); void Release(StatisticsInfo* statistics_info); @@ -97,12 +96,10 @@ class Segment { void GcAllType(const std::map& ttl_st_map, StatisticsInfo* statistics_info); MemTableIterator* NewIterator(const Slice& key, Ticket& ticket, type::CompressType compress_type); // NOLINT - MemTableIterator* NewIterator(const Slice& key, uint32_t idx, - Ticket& ticket, type::CompressType compress_type); // NOLINT + MemTableIterator* NewIterator(const Slice& key, uint32_t idx, Ticket& ticket, // NOLINT + type::CompressType compress_type); - uint64_t GetIdxCnt() const { - return idx_cnt_vec_[0]->load(std::memory_order_relaxed); - } + uint64_t GetIdxCnt() const { return idx_cnt_vec_[0]->load(std::memory_order_relaxed); } int GetIdxCnt(uint32_t ts_idx, uint64_t& ts_cnt) { // NOLINT uint32_t real_idx = 0; @@ -145,10 +142,14 @@ class Segment { void ReleaseAndCount(const std::vector& id_vec, StatisticsInfo* statistics_info); private: - void FreeList(uint32_t ts_idx, ::openmldb::base::Node* node, - StatisticsInfo* statistics_info); + void FreeList(uint32_t ts_idx, ::openmldb::base::Node* node, StatisticsInfo* statistics_info); void SplitList(KeyEntry* entry, uint64_t ts, ::openmldb::base::Node** node); + bool ListContains(KeyEntry* entry, uint64_t time, DataBlock* row, bool check_all_time); + + bool PutUnlock(const Slice& key, uint64_t time, DataBlock* row, bool put_if_absent = false, + bool check_all_time = false); + private: KeyEntries* entries_; std::mutex mu_; diff --git a/src/storage/segment_test.cc b/src/storage/segment_test.cc index c51c0984473..e43461c47e6 100644 --- a/src/storage/segment_test.cc +++ b/src/storage/segment_test.cc @@ -424,6 +424,82 @@ TEST_F(SegmentTest, TestDeleteRange) { CheckStatisticsInfo(CreateStatisticsInfo(20, 1012, 20 * (6 + sizeof(DataBlock))), gc_info); } +TEST_F(SegmentTest, PutIfAbsent) { + { + Segment segment(8); // so ts_cnt_ == 1 + // check all time == false + segment.Put("PK", 1, "test1", 5, true); + segment.Put("PK", 1, "test2", 5, true); // even key&time is the same, different value means different record + ASSERT_EQ(2, (int64_t)segment.GetIdxCnt()); + ASSERT_EQ(1, (int64_t)segment.GetPkCnt()); + segment.Put("PK", 2, "test3", 5, true); + segment.Put("PK", 2, "test4", 5, true); + segment.Put("PK", 3, "test5", 5, true); + segment.Put("PK", 3, "test6", 5, true); + ASSERT_EQ(6, (int64_t)segment.GetIdxCnt()); + // insert exists rows + segment.Put("PK", 2, "test3", 5, true); + segment.Put("PK", 1, "test1", 5, true); + segment.Put("PK", 1, "test2", 5, true); + segment.Put("PK", 3, "test6", 5, true); + ASSERT_EQ(6, (int64_t)segment.GetIdxCnt()); + // new rows + segment.Put("PK", 2, "test7", 5, true); + ASSERT_EQ(7, (int64_t)segment.GetIdxCnt()); + segment.Put("PK", 0, "test8", 5, true); // seek to last, next is empty + ASSERT_EQ(8, (int64_t)segment.GetIdxCnt()); + } + + { + // support when 
ts_cnt_ != 1 too + std::vector ts_idx_vec = {1, 3}; + Segment segment(8, ts_idx_vec); + ASSERT_EQ(2, (int64_t)segment.GetTsCnt()); + std::string key = "PK"; + uint64_t ts = 1669013677221000; + // the same ts + for (int j = 0; j < 2; j++) { + DataBlock* data = new DataBlock(2, key.c_str(), key.length()); + std::map ts_map = {{1, ts}, {3, ts}}; + segment.Put(Slice(key), ts_map, data, true); + } + ASSERT_EQ(1, GetCount(&segment, 1)); + ASSERT_EQ(1, GetCount(&segment, 3)); + } + + { + // put ts_map contains DEFAULT_TS_COL_ID + std::vector ts_idx_vec = {DEFAULT_TS_COL_ID}; + Segment segment(8, ts_idx_vec); + ASSERT_EQ(1, (int64_t)segment.GetTsCnt()); + std::string key = "PK"; + std::map ts_map = {{DEFAULT_TS_COL_ID, 100}}; // cur time == 100 + auto* block = new DataBlock(1, "test1", 5); + segment.Put(Slice(key), ts_map, block, true); + ASSERT_EQ(1, GetCount(&segment, DEFAULT_TS_COL_ID)); + ts_map = {{DEFAULT_TS_COL_ID, 200}}; + block = new DataBlock(1, "test1", 5); + segment.Put(Slice(key), ts_map, block, true); + ASSERT_EQ(1, GetCount(&segment, DEFAULT_TS_COL_ID)); + } + + { + // put ts_map contains DEFAULT_TS_COL_ID + std::vector ts_idx_vec = {DEFAULT_TS_COL_ID, 1, 3}; + Segment segment(8, ts_idx_vec); + ASSERT_EQ(3, (int64_t)segment.GetTsCnt()); + std::string key = "PK"; + std::map ts_map = {{DEFAULT_TS_COL_ID, 100}}; // cur time == 100 + auto* block = new DataBlock(1, "test1", 5); + segment.Put(Slice(key), ts_map, block, true); + ASSERT_EQ(1, GetCount(&segment, DEFAULT_TS_COL_ID)); + ts_map = {{DEFAULT_TS_COL_ID, 200}}; + block = new DataBlock(1, "test1", 5); + segment.Put(Slice(key), ts_map, block, true); + ASSERT_EQ(1, GetCount(&segment, DEFAULT_TS_COL_ID)); + } +} + } // namespace storage } // namespace openmldb diff --git a/src/storage/snapshot_test.cc b/src/storage/snapshot_test.cc index 910a8bc7724..e9dd679eafc 100644 --- a/src/storage/snapshot_test.cc +++ b/src/storage/snapshot_test.cc @@ -1085,7 +1085,7 @@ TEST_F(SnapshotTest, MakeSnapshotAbsOrLat) { SchemaCodec::SetColumnDesc(table_meta->add_column_desc(), "value", ::openmldb::type::kString); SchemaCodec::SetIndex(table_meta->add_column_key(), "index1", "card|merchant", "", ::openmldb::type::kAbsOrLat, 0, 1); - std::shared_ptr table = std::make_shared(*table_meta); + std::shared_ptr
diff --git a/src/storage/table.h b/src/storage/table.h
index 0766e4cf6c4..4c4a1f011f7 100644
--- a/src/storage/table.h
+++ b/src/storage/table.h
@@ -22,6 +22,7 @@
 #include
 #include

+#include "absl/status/status.h"
 #include "codec/codec.h"
 #include "proto/tablet.pb.h"
 #include "storage/iterator.h"
@@ -50,17 +51,16 @@ class Table {
     int InitColumnDesc();

     virtual bool Put(const std::string& pk, uint64_t time, const char* data, uint32_t size) = 0;

+    // DO NOT set different default value in derived class
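+    // (a default argument is bound by the caller's static type, so an override
+    // that changed it would silently disagree with calls made through Table*)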
+    virtual absl::Status Put(uint64_t time, const std::string& value, const Dimensions& dimensions,
+                             bool put_if_absent = false) = 0;

-    virtual bool Put(uint64_t time, const std::string& value, const Dimensions& dimensions) = 0;
-
-    bool Put(const ::openmldb::api::LogEntry& entry) {
-        return Put(entry.ts(), entry.value(), entry.dimensions());
-    }
+    bool Put(const ::openmldb::api::LogEntry& entry) { return Put(entry.ts(), entry.value(), entry.dimensions()).ok(); }

     virtual bool Delete(const ::openmldb::api::LogEntry& entry) = 0;
-    virtual bool Delete(uint32_t idx, const std::string& key,
-                        const std::optional<uint64_t>& start_ts, const std::optional<uint64_t>& end_ts) = 0;
+    virtual bool Delete(uint32_t idx, const std::string& key, const std::optional<uint64_t>& start_ts,
+                        const std::optional<uint64_t>& end_ts) = 0;

     virtual TableIterator* NewIterator(const std::string& pk, Ticket& ticket) = 0;  // NOLINT
@@ -88,9 +88,7 @@ class Table {
         }
         return "";
     }
-    inline ::openmldb::common::StorageMode GetStorageMode() const {
-        return storage_mode_;
-    }
+    inline ::openmldb::common::StorageMode GetStorageMode() const { return storage_mode_; }

     inline uint32_t GetId() const { return id_; }

     inline uint32_t GetIdxCnt() const { return table_index_.Size(); }
@@ -173,7 +171,7 @@ class Table {
     virtual uint64_t GetRecordByteSize() const = 0;
     virtual uint64_t GetRecordIdxByteSize() = 0;

-    virtual int GetCount(uint32_t index, const std::string& pk, uint64_t& count) = 0;   // NOLINT
+    virtual int GetCount(uint32_t index, const std::string& pk, uint64_t& count) = 0;  // NOLINT

 protected:
     void UpdateTTL();
diff --git a/src/storage/table_iterator_test.cc b/src/storage/table_iterator_test.cc
index 7ba932422e1..3af20940266 100644
--- a/src/storage/table_iterator_test.cc
+++ b/src/storage/table_iterator_test.cc
@@ -450,7 +450,7 @@ TEST_P(TableIteratorTest, SeekNonExistent) {
     ASSERT_EQ(0, now - wit->GetKey());
 }

-INSTANTIATE_TEST_CASE_P(TestMemAndHDD, TableIteratorTest,
+INSTANTIATE_TEST_SUITE_P(TestMemAndHDD, TableIteratorTest,
                         ::testing::Values(::openmldb::common::kMemory, ::openmldb::common::kHDD));

 }  // namespace storage
diff --git a/src/storage/table_test.cc b/src/storage/table_test.cc
index 251e92986c6..43b3508822e 100644
--- a/src/storage/table_test.cc
+++ b/src/storage/table_test.cc
@@ -198,7 +198,7 @@ TEST_P(TableTest, MultiDimissionPut0) {
     ::openmldb::codec::SDKCodec sdk_codec(meta);
     std::string result;
     sdk_codec.EncodeRow({"d0", "d1", "d2"}, &result);
-    bool ok = table->Put(1, result, dimensions);
+    bool ok = table->Put(1, result, dimensions).ok();
     ASSERT_TRUE(ok);
     // some functions in disk table need to be implemented.
     // refer to issue #1238
@@ -808,7 +808,7 @@ TEST_P(TableTest, TableIteratorTS) {
         dim->set_key(row[1]);
         std::string value;
         ASSERT_EQ(0, codec.EncodeRow(row, &value));
-        table->Put(0, value, request.dimensions());
+        ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
     }
     TableIterator* it = table->NewTraverseIterator(0);
     it->SeekToFirst();
@@ -921,7 +921,7 @@ TEST_P(TableTest, TraverseIteratorCount) {
         dim->set_key(row[1]);
         std::string value;
         ASSERT_EQ(0, codec.EncodeRow(row, &value));
-        table->Put(0, value, request.dimensions());
+        ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
     }
     TableIterator* it = table->NewTraverseIterator(0);
     it->SeekToFirst();
@@ -1048,7 +1048,7 @@ TEST_P(TableTest, AbsAndLatSetGet) {
         dim->set_key("mcc");
         std::string value;
         ASSERT_EQ(0, codec.EncodeRow(row, &value));
-        table->Put(0, value, request.dimensions());
+        ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
     }
     // test get and set ttl
     ASSERT_EQ(10, (int64_t)table->GetIndex(0)->GetTTL()->abs_ttl / (10 * 6000));
@@ -1149,7 +1149,7 @@ TEST_P(TableTest, AbsOrLatSetGet) {
         dim->set_key("mcc");
         std::string value;
         ASSERT_EQ(0, codec.EncodeRow(row, &value));
-        table->Put(0, value, request.dimensions());
+        ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
     }
     // test get and set ttl
     ASSERT_EQ(10, (int64_t)table->GetIndex(0)->GetTTL()->abs_ttl / (10 * 6000));
@@ -1562,7 +1562,7 @@ TEST_P(TableTest, TraverseIteratorCountWithLimit) {
         dim->set_key(row[1]);
         std::string value;
         ASSERT_EQ(0, codec.EncodeRow(row, &value));
-        table->Put(0, value, request.dimensions());
+        ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
     }

     TableIterator* it = table->NewTraverseIterator(0);
@@ -1669,7 +1669,7 @@ TEST_P(TableTest, TSColIDLength) {
         dim1->set_key(row[0]);
         std::string value;
         ASSERT_EQ(0, codec.EncodeRow(row, &value));
-        table->Put(0, value, request.dimensions());
+        ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
     }

     TableIterator* it = table->NewTraverseIterator(0);
@@ -1727,7 +1727,7 @@ TEST_P(TableTest, MultiDimensionPutTS) {
         dim->set_key(row[1]);
         std::string value;
         ASSERT_EQ(0, codec.EncodeRow(row, &value));
-        table->Put(0, value, request.dimensions());
+        ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
     }
     TableIterator* it = table->NewTraverseIterator(0);
     it->SeekToFirst();
@@ -1781,7 +1781,7 @@ TEST_P(TableTest, MultiDimensionPutTS1) {
         dim->set_key(row[1]);
         std::string value;
         ASSERT_EQ(0, codec.EncodeRow(row, &value));
-        table->Put(0, value, request.dimensions());
+        ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
     }
     TableIterator* it = table->NewTraverseIterator(0);
     it->SeekToFirst();
@@ -1823,7 +1823,7 @@ TEST_P(TableTest, MultiDimissionPutTS2) {
     ::openmldb::codec::SDKCodec sdk_codec(meta);
     std::string result;
     sdk_codec.EncodeRow({"d0", "d1", "d2"}, &result);
-    bool ok = table->Put(100, result, dimensions);
+    bool ok = table->Put(100, result, dimensions).ok();
     ASSERT_TRUE(ok);

     TableIterator* it = table->NewTraverseIterator(0);
@@ -1885,7 +1885,7 @@ TEST_P(TableTest, AbsAndLat) {
         }
         std::string value;
         ASSERT_EQ(0, codec.EncodeRow(row, &value));
-        table->Put(0, value, request.dimensions());
+        ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
     }

     for (int i = 0; i <= 5; i++) {
@@ -1938,10 +1938,11 @@ TEST_P(TableTest, NegativeTs) {
         dim->set_key(row[0]);
         std::string value;
         ASSERT_EQ(0, codec.EncodeRow(row, &value));
-        ASSERT_FALSE(table->Put(0, value, request.dimensions()));
+        auto st = table->Put(0, value, request.dimensions());
+        ASSERT_TRUE(absl::IsInvalidArgument(st)) << st.ToString();
     }

-INSTANTIATE_TEST_CASE_P(TestMemAndHDD, TableTest,
+INSTANTIATE_TEST_SUITE_P(TestMemAndHDD, TableTest,
                         ::testing::Values(::openmldb::common::kMemory, ::openmldb::common::kHDD));

 }  // namespace storage
diff --git a/src/tablet/tablet_impl.cc b/src/tablet/tablet_impl.cc
index 8691cf2f90b..8b2b446e874 100644
--- a/src/tablet/tablet_impl.cc
+++ b/src/tablet/tablet_impl.cc
@@ -767,7 +767,8 @@ void TabletImpl::Put(RpcController* controller, const ::openmldb::api::PutReques
     if (request->ts_dimensions_size() > 0) {
         entry.mutable_ts_dimensions()->CopyFrom(request->ts_dimensions());
     }
-    bool ok = false;
+
+    absl::Status st;
     if (request->dimensions_size() > 0) {
         int32_t ret_code = CheckDimessionPut(request, table->GetIdxCnt());
         if (ret_code != 0) {
@@ -776,16 +777,27 @@ void TabletImpl::Put(RpcController* controller, const ::openmldb::api::PutReques
             return;
         }
         DLOG(INFO) << "put data to tid " << tid << " pid " << pid << " with key " << request->dimensions(0).key();
-        ok = table->Put(entry.ts(), entry.value(), entry.dimensions());
+        // 1. normal put: ok, or invalid data
+        // 2. put if absent: ok, already exists (ignored), or invalid data
+        st = table->Put(entry.ts(), entry.value(), entry.dimensions(), request->put_if_absent());
     }
-    if (!ok) {
+
+    if (!st.ok()) {
+        if (request->put_if_absent() && absl::IsAlreadyExists(st)) {
+            // not a failure, but we shouldn't write a log entry for it
+            response->set_code(::openmldb::base::ReturnCode::kOk);
+            response->set_msg("exists but ignore");
+            return;
+        }
+        LOG(WARNING) << st.ToString();
         response->set_code(::openmldb::base::ReturnCode::kPutFailed);
-        response->set_msg("put failed");
+        response->set_msg(st.ToString());
         return;
     }
     response->set_code(::openmldb::base::ReturnCode::kOk);

     std::shared_ptr<LogReplicator> replicator;
+    bool ok = false;
     do {
         replicator = GetReplicator(request->tid(), request->pid());
         if (!replicator) {
@@ -2220,9 +2232,8 @@ void TabletImpl::AppendEntries(RpcController* controller, const ::openmldb::api:
         return;
     }
     if (entry.has_method_type() && entry.method_type() == ::openmldb::api::MethodType::kDelete) {
-        table->Delete(entry);
-    }
-    if (!table->Put(entry)) {
+        table->Delete(entry);  // TODO(hw): error handle
+    } else if (!table->Put(entry)) {  // put only if the entry is not a delete
        PDLOG(WARNING, "fail to put entry. tid %u pid %u", tid, pid);
         response->set_code(::openmldb::base::ReturnCode::kFailToAppendEntriesToReplicator);
         response->set_msg("fail to append entry to table");
diff --git a/src/tablet/tablet_impl_func_test.cc b/src/tablet/tablet_impl_func_test.cc
index c84729f288d..c07084a396d 100644
--- a/src/tablet/tablet_impl_func_test.cc
+++ b/src/tablet/tablet_impl_func_test.cc
@@ -89,7 +89,7 @@ void CreateBaseTable(::openmldb::storage::Table*& table,  // NOLINT
         dim->set_key(row[1]);
         std::string value;
         ASSERT_EQ(0, codec.EncodeRow(row, &value));
-        ASSERT_TRUE(table->Put(0, value, request.dimensions()));
+        ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
     }
     return;
 }
@@ -389,7 +389,7 @@ TEST_P(TabletFuncTest, GetTimeIndex_ts1_iterator) {
     RunGetTimeIndexAssert(&query_its, base_ts, base_ts - 100);
 }

-INSTANTIATE_TEST_CASE_P(TabletMemAndHDD, TabletFuncTest,
+INSTANTIATE_TEST_SUITE_P(TabletMemAndHDD, TabletFuncTest,
                         ::testing::Values(::openmldb::common::kMemory, ::openmldb::common::kHDD,
                                           ::openmldb::common::kSSD));
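For reference, the storage-layer and tablet-layer pieces above combine into one caller-facing contract for the new status-returning Put. The sketch below is illustrative only and not part of the patch: PutOnce is a hypothetical helper, and it assumes the Table interface and Dimensions alias declared in src/storage/table.h in this diff.

#include <string>

#include "absl/status/status.h"
#include "storage/table.h"

namespace openmldb {
namespace storage {

// Hypothetical helper (not in the patch): insert a row at most once.
// - fresh insert            -> absl::OkStatus()
// - duplicate row           -> AlreadyExists, downgraded to ok here, mirroring
//                              the "exists but ignore" branch in TabletImpl::Put
// - bad input (e.g. ts < 0) -> InvalidArgument, propagated to the caller
absl::Status PutOnce(Table* table, uint64_t time, const std::string& value, const Dimensions& dims) {
    absl::Status st = table->Put(time, value, dims, /*put_if_absent=*/true);
    if (absl::IsAlreadyExists(st)) {
        return absl::OkStatus();
    }
    return st;
}

}  // namespace storage
}  // namespace openmldb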