diff --git a/demo/usability_testing/data_mocker.py b/demo/usability_testing/data_mocker.py
index f873daec9dc..6729ef0b70b 100644
--- a/demo/usability_testing/data_mocker.py
+++ b/demo/usability_testing/data_mocker.py
@@ -6,6 +6,7 @@
from typing import Optional
import numpy as np
import pandas as pd
+import dateutil
# to support saving csv, and faster parquet, we don't use faker-cli directly
@@ -146,8 +147,9 @@ def type_converter(sql_type):
if sql_type in ['varchar', 'string']:
# TODO(hw): set max length
return 'pystr', {}
+    # timestamp must be > 0 because the tablet checks it on insert; use UTC
if sql_type in ['date', 'timestamp']:
- return 'iso8601', {}
+ return 'iso8601', {"tzinfo": dateutil.tz.UTC}
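+    # e.g. Faker().iso8601(tzinfo=dateutil.tz.UTC) -> '2021-01-01T00:00:00+00:00' (illustrative output)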
if sql_type in ['float', 'double']:
return 'pyfloat', ranges[sql_type]
return 'py' + sql_type, {}
diff --git a/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md b/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md
index a44f699eed3..0113ef730b0 100644
--- a/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md
+++ b/docs/zh/openmldb_sql/ddl/CREATE_TABLE_STATEMENT.md
@@ -233,8 +233,8 @@ IndexOption ::=
| ----------- | ------------------------------------------------------------ | ---------------------------------------------------- | ------------------------------------------------------------ |
| `ABSOLUTE` | The TTL value is an expiration time, given as a time period such as `100m, 12h, 1d, 365d`. The maximum configurable expiration time is `15768000m` (i.e. 30 years). | Records are evicted once they expire. | `INDEX(KEY=col1, TS=std_time, TTL_TYPE=absolute, TTL=100m)` OpenMLDB will delete data older than 100 minutes. |
| `LATEST` | The TTL value is the maximum number of records to keep, i.e. the maximum number of records allowed under one index. At most 1000 can be configured. | Records beyond the maximum count are evicted. | `INDEX(KEY=col1, TS=std_time, TTL_TYPE=LATEST, TTL=10)`. OpenMLDB keeps only the latest 10 records and deletes earlier ones. |
-| `ABSORLAT` | Configures an expiration time and a maximum record count. The value is a 2-tuple such as `(100m, 10), (1d, 1)`; the maximum is `(15768000m, 1000)`. | A record is evicted if and only if it has expired **or** the record count exceeds the maximum. | `INDEX(key=c1, ts=c6, ttl=(120min, 100), ttl_type=absorlat)`. Records are evicted when there are more than 100 of them **or** when they expire. |
-| `ABSANDLAT` | Configures an expiration time and a maximum record count. The value is a 2-tuple such as `(100m, 10), (1d, 1)`; the maximum is `(15768000m, 1000)`. | A record is evicted only when it has expired **and** the record count exceeds the maximum. | `INDEX(key=c1, ts=c6, ttl=(120min, 100), ttl_type=absandlat)`. Records are evicted when there are more than 100 of them **and** they have expired. |
+| `ABSORLAT` | Configures an expiration time and a maximum record count. The value is a 2-tuple such as `(100m, 10), (1d, 1)`; the maximum is `(15768000m, 1000)`. | A record is evicted if and only if it has expired **or** the record count exceeds the maximum. | `INDEX(key=c1, ts=c6, ttl=(120m, 100), ttl_type=absorlat)`. Records are evicted when there are more than 100 of them **or** when they expire. |
+| `ABSANDLAT` | Configures an expiration time and a maximum record count. The value is a 2-tuple such as `(100m, 10), (1d, 1)`; the maximum is `(15768000m, 1000)`. | A record is evicted only when it has expired **and** the record count exceeds the maximum. | `INDEX(key=c1, ts=c6, ttl=(120m, 100), ttl_type=absandlat)`. Records are evicted when there are more than 100 of them **and** they have expired. |
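+
+A minimal sketch (hypothetical table `t1`) of an index combining a 2-tuple TTL with `absandlat`:
+
+```sql
+CREATE TABLE t1 (col1 string, std_time timestamp,
+INDEX(KEY=col1, TS=std_time, TTL_TYPE=absandlat, TTL=(1440m, 100)));
+```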
```{note}
The limits on the maximum expiration time and the maximum record count exist for performance reasons. If you must configure a larger TTL value, use UpdateTTL to increase it (the max limits do not apply there), or adjust the nameserver configs `absolute_ttl_max` and `latest_ttl_max` and restart for them to take effect.
diff --git a/docs/zh/openmldb_sql/dml/INSERT_STATEMENT.md b/docs/zh/openmldb_sql/dml/INSERT_STATEMENT.md
index 6ecf98390a3..4799e557577 100644
--- a/docs/zh/openmldb_sql/dml/INSERT_STATEMENT.md
+++ b/docs/zh/openmldb_sql/dml/INSERT_STATEMENT.md
@@ -5,7 +5,7 @@ OpenMLDB supports inserting one row or multiple rows at a time.
## Syntax
```
-INSERT INFO tbl_name (column_list) VALUES (value_list) [, value_list ...]
+INSERT [[OR] IGNORE] INTO tbl_name (column_list) VALUES (value_list) [, value_list ...]
column_list:
col_name [, col_name] ...
@@ -16,6 +16,7 @@ value_list:
**Notes**
- `INSERT` can only be used in online mode
+- By default, `INSERT` does not deduplicate. `INSERT OR IGNORE` ignores data that already exists in the table, so the statement can be retried repeatedly (see the sketch below).
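+
+A minimal sketch of a retryable insert (hypothetical table `t1` and values):
+
+```sql
+INSERT OR IGNORE INTO t1 (col1, std_time) VALUES ('key1', 1590738989000);
+-- running the same statement again will not create duplicate rows
+```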
## Examples
diff --git a/docs/zh/openmldb_sql/dml/LOAD_DATA_STATEMENT.md b/docs/zh/openmldb_sql/dml/LOAD_DATA_STATEMENT.md
index b3c7ffc55bf..d2c456b2913 100644
--- a/docs/zh/openmldb_sql/dml/LOAD_DATA_STATEMENT.md
+++ b/docs/zh/openmldb_sql/dml/LOAD_DATA_STATEMENT.md
@@ -58,6 +58,7 @@ FilePathPattern
| load_mode | String | cluster | `load_mode='local'` only supports importing csv local files into online storage; it inserts the data synchronously through the local client. `load_mode='cluster'` only supports the cluster version; it inserts the data via Spark, in synchronous or asynchronous mode. |
| thread | Integer | 1 | Only effective for local file import, i.e. `load_mode='local'` or the standalone version; the number of threads used for local insertion. The maximum value is `50`. |
| writer_type | String | single | The writer type for cluster-version online import. Optional values are `single` and `batch`; the default is `single`. `single` writes each row as it is read, saving memory. `batch` reads the whole rdd partition and validates the data types before writing to the cluster, which needs more memory. In some cases `batch` mode helps to identify the rows that were not written, making it convenient to retry just those rows. |
+| put_if_absent | Boolean | false | When the source data contains no duplicate rows and does not duplicate rows already in the table, this option avoids inserting duplicate data; it is especially useful for retrying after a failed job. Equivalent to `INSERT OR IGNORE`. See below for details. |
```{note}
In the cluster version, the `LOAD DATA INFILE` statement decides whether to import data into online or offline storage based on the current execution mode (execute_mode). The standalone version has no such storage distinction and only imports into online storage; it also does not support the `deep_copy` option.
@@ -73,6 +74,7 @@ FilePathPattern
Therefore, please use absolute paths whenever possible. For standalone tests, prefix local files with `file://`; in production, file systems such as HDFS are recommended.
```
+
## SQL Statement Template
```sql
@@ -158,3 +160,12 @@ null,null
In the second row, both columns are two double quotes.
- In cluster mode the default quote is `"`, so this row is two empty strings.
- In local mode the default quote is `\0`, so both columns in this row are two double quotes. In local mode quote can be configured as `"`, but its escape rule treats `""` as a single `"`, which is inconsistent with Spark; see [issue3015](https://github.com/4paradigm/OpenMLDB/issues/3015) for details.
+
+## PutIfAbsent Notes
+
+PutIfAbsent is a special option that avoids inserting duplicate data. It needs only one configuration and is simple to use, which makes it especially suitable for retrying after a failed LOAD DATA job; it is equivalent to `INSERT OR IGNORE`. If the data you are importing itself contains duplicates, importing with PutIfAbsent will lose some of it. If you need to keep duplicate rows, do not use this option; deduplicate by other means before importing. A minimal usage sketch follows the performance notes below.
+
+PutIfAbsent has the extra overhead of deduplication, so its performance depends on the complexity of that deduplication:
+
+- If the table only has ts indexes and the amount of data under the same key+ts is less than 10k (for exact deduplication, whole rows are compared one by one under the same key+ts), PutIfAbsent performs reasonably well; import time is usually within 2x of a normal import.
+- If the table has a time index (an index whose ts column is empty), or a ts index has more than 100k rows under the same key+ts, PutIfAbsent performs poorly; import time may exceed 10x of a normal import, making it unusable in practice. Under such data conditions, deduplicating before import is the better choice.
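+
+A minimal sketch of a retryable online import (hypothetical file path and table):
+
+```sql
+SET @@execute_mode='online';
+LOAD DATA INFILE 'file:///tmp/data.csv' INTO TABLE t1
+OPTIONS (load_mode='cluster', put_if_absent=true);
+-- if the job fails midway, simply run the same statement again
+```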
diff --git a/docs/zh/quickstart/beginner_must_read.md b/docs/zh/quickstart/beginner_must_read.md
index 60522283942..ad403a6b423 100644
--- a/docs/zh/quickstart/beginner_must_read.md
+++ b/docs/zh/quickstart/beginner_must_read.md
@@ -69,6 +69,16 @@ OpenMLDB separates online and offline storage and computation, so you need to be clear about where your data
For how to design your data inflow and outflow, see [Common architecture integrations of OpenMLDB in real-time decision systems](../tutorial/app_arch.md).
+### Online Tables
+
+An online table keeps its data in memory, and also uses disk for backup and recovery. You can check an online table's row count with `select count(*) from t1`, or inspect the table state with `show table status` (the status may lag a little; wait a moment and query again).
+
+An online table can have multiple indexes, which can be viewed with `desc <table_name>`. When one row is written, a copy is written into every index; the difference is how each index groups and sorts the data. Because each index also has its own TTL eviction, the indexes may hold different amounts of data. The results of `select count(*) from t1` and `show table status` are the row count of the first index and do not represent the other indexes. Which index a SQL query uses is the optimal one chosen by the SQL Engine, and it can be inspected via the SQL physical plan.
+
+When creating a table you may specify indexes or not; if none is specified, a default index is created. A default index has no ts column (it sorts by the current time, which we call a time index) and never evicts data, so it can serve as the baseline for checking whether the row count is accurate. But such an index occupies a lot of memory, and the first index cannot be dropped yet (support is planned); you can modify its TTL via the NS Client to evict data and reduce its memory usage.
+
+The time index (an index without ts) also affects PutIfAbsent imports. If your data import may fail midway, you have no other way to delete or deduplicate, and you want to retry the import with PutIfAbsent, please evaluate your data against the [PutIfAbsent Notes](../openmldb_sql/dml/LOAD_DATA_STATEMENT.md#putifabsent-notes) first, to avoid extremely poor PutIfAbsent performance.
+
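+A quick sanity check of an online table (a sketch assuming table `t1`):
+
+```sql
+desc t1;                 -- list indexes and their TTLs
+select count(*) from t1; -- row count of the first index
+show table status;       -- table state, may lag slightly
+```
+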
## Source Data
### LOAD DATA
diff --git a/hybridse/include/node/node_manager.h b/hybridse/include/node/node_manager.h
index 6949faf6f88..914feaca435 100644
--- a/hybridse/include/node/node_manager.h
+++ b/hybridse/include/node/node_manager.h
@@ -166,7 +166,7 @@ class NodeManager {
SqlNode *MakeInsertTableNode(const std::string &db_name,
const std::string &table_name,
const ExprListNode *column_names,
- const ExprListNode *values);
+ const ExprListNode *values, InsertStmt::InsertMode insert_mode);
CreateStmt *MakeCreateTableNode(bool op_if_not_exist,
const std::string &db_name,
const std::string &table_name,
diff --git a/hybridse/include/node/sql_node.h b/hybridse/include/node/sql_node.h
index 8d641ad8283..801b9063aa4 100644
--- a/hybridse/include/node/sql_node.h
+++ b/hybridse/include/node/sql_node.h
@@ -1865,19 +1865,35 @@ class ColumnDefNode : public SqlNode {
class InsertStmt : public SqlNode {
public:
+ // ref zetasql ASTInsertStatement
+ enum InsertMode {
+ DEFAULT_MODE, // plain INSERT
+ REPLACE, // INSERT OR REPLACE
+ UPDATE, // INSERT OR UPDATE
+ IGNORE // INSERT OR IGNORE
+ };
+
InsertStmt(const std::string &db_name,
const std::string &table_name,
const std::vector<std::string> &columns,
- const std::vector<ExprNode *> &values)
+ const std::vector<ExprNode *> &values,
+ InsertMode insert_mode)
: SqlNode(kInsertStmt, 0, 0),
db_name_(db_name),
table_name_(table_name),
columns_(columns),
values_(values),
- is_all_(columns.empty()) {}
+ is_all_(columns.empty()),
+ insert_mode_(insert_mode) {}
- InsertStmt(const std::string &db_name, const std::string &table_name, const std::vector<ExprNode *> &values)
- : SqlNode(kInsertStmt, 0, 0), db_name_(db_name), table_name_(table_name), values_(values), is_all_(true) {}
+ InsertStmt(const std::string &db_name, const std::string &table_name, const std::vector<ExprNode *> &values,
+ InsertMode insert_mode)
+ : SqlNode(kInsertStmt, 0, 0),
+ db_name_(db_name),
+ table_name_(table_name),
+ values_(values),
+ is_all_(true),
+ insert_mode_(insert_mode) {}
void Print(std::ostream &output, const std::string &org_tab) const;
const std::string db_name_;
@@ -1885,6 +1901,7 @@ class InsertStmt : public SqlNode {
const std::vector<std::string> columns_;
const std::vector<ExprNode *> values_;
const bool is_all_;
+ const InsertMode insert_mode_;
};
class StorageModeNode : public SqlNode {
diff --git a/hybridse/src/node/node_manager.cc b/hybridse/src/node/node_manager.cc
index 86d51249e19..192a5214f4a 100644
--- a/hybridse/src/node/node_manager.cc
+++ b/hybridse/src/node/node_manager.cc
@@ -792,9 +792,10 @@ AllNode *NodeManager::MakeAllNode(const std::string &relation_name, const std::s
}
SqlNode *NodeManager::MakeInsertTableNode(const std::string &db_name, const std::string &table_name,
- const ExprListNode *columns_expr, const ExprListNode *values) {
+ const ExprListNode *columns_expr, const ExprListNode *values,
+ InsertStmt::InsertMode insert_mode) {
if (nullptr == columns_expr) {
- InsertStmt *node_ptr = new InsertStmt(db_name, table_name, values->children_);
+ InsertStmt *node_ptr = new InsertStmt(db_name, table_name, values->children_, insert_mode);
return RegisterNode(node_ptr);
} else {
std::vector<std::string> column_names;
@@ -811,7 +812,7 @@ SqlNode *NodeManager::MakeInsertTableNode(const std::string &db_name, const std:
}
}
}
- InsertStmt *node_ptr = new InsertStmt(db_name, table_name, column_names, values->children_);
+ InsertStmt *node_ptr = new InsertStmt(db_name, table_name, column_names, values->children_, insert_mode);
return RegisterNode(node_ptr);
}
}
diff --git a/hybridse/src/node/sql_node_test.cc b/hybridse/src/node/sql_node_test.cc
index e2938656dcc..f047a9cd737 100644
--- a/hybridse/src/node/sql_node_test.cc
+++ b/hybridse/src/node/sql_node_test.cc
@@ -308,7 +308,8 @@ TEST_F(SqlNodeTest, MakeInsertNodeTest) {
value_expr_list->PushBack(value4);
ExprListNode *insert_values = node_manager_->MakeExprList();
insert_values->PushBack(value_expr_list);
- SqlNode *node_ptr = node_manager_->MakeInsertTableNode("", "t1", column_expr_list, insert_values);
+ SqlNode *node_ptr = node_manager_->MakeInsertTableNode("", "t1", column_expr_list, insert_values,
+ InsertStmt::InsertMode::DEFAULT_MODE);
ASSERT_EQ(kInsertStmt, node_ptr->GetType());
InsertStmt *insert_stmt = dynamic_cast<InsertStmt *>(node_ptr);
diff --git a/hybridse/src/planv2/ast_node_converter.cc b/hybridse/src/planv2/ast_node_converter.cc
index 5d9eb939113..53580b27874 100644
--- a/hybridse/src/planv2/ast_node_converter.cc
+++ b/hybridse/src/planv2/ast_node_converter.cc
@@ -1962,8 +1962,9 @@ base::Status ConvertInsertStatement(const zetasql::ASTInsertStatement* root, nod
}
CHECK_TRUE(nullptr == root->query(), common::kSqlAstError, "Un-support insert statement with query");
- CHECK_TRUE(zetasql::ASTInsertStatement::InsertMode::DEFAULT_MODE == root->insert_mode(), common::kSqlAstError,
- "Un-support insert mode ", root->GetSQLForInsertMode());
+ CHECK_TRUE(zetasql::ASTInsertStatement::InsertMode::DEFAULT_MODE == root->insert_mode() ||
+ zetasql::ASTInsertStatement::InsertMode::IGNORE == root->insert_mode(),
+ common::kSqlAstError, "Un-support insert mode ", root->GetSQLForInsertMode());
CHECK_TRUE(nullptr == root->returning(), common::kSqlAstError,
"Un-support insert statement with return clause currently", root->GetSQLForInsertMode());
CHECK_TRUE(nullptr == root->assert_rows_modified(), common::kSqlAstError,
@@ -2000,8 +2001,8 @@ base::Status ConvertInsertStatement(const zetasql::ASTInsertStatement* root, nod
if (names.size() == 2) {
db_name = names[0];
}
- *output =
- dynamic_cast<node::InsertStmt *>(node_manager->MakeInsertTableNode(db_name, table_name, column_list, rows));
+ *output = dynamic_cast<node::InsertStmt *>(node_manager->MakeInsertTableNode(
+ db_name, table_name, column_list, rows, static_cast<node::InsertStmt::InsertMode>(root->insert_mode())));
return base::Status::OK();
}
base::Status ConvertDropStatement(const zetasql::ASTDropStatement* root, node::NodeManager* node_manager,
diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/LoadDataPlan.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/LoadDataPlan.scala
index a04b46ab650..ec9946b839a 100644
--- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/LoadDataPlan.scala
+++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/LoadDataPlan.scala
@@ -55,16 +55,19 @@ object LoadDataPlan {
loadDataSql)
// write
- logger.info("write data to storage {}, writer[mode {}], is deep? {}", storage, mode, deepCopy.toString)
+ logger.info("write data to storage {}, writer mode {}, is deep {}", storage, mode, deepCopy.toString)
if (storage == "online") { // Import online data
require(deepCopy && mode == "append", "import to online storage, can't do soft copy, and mode must be append")
val writeType = extra.get("writer_type").get
+ val putIfAbsent = extra.get("put_if_absent").get.toBoolean
+ logger.info(s"online write type ${writeType}, put if absent ${putIfAbsent}")
val writeOptions = Map(
"db" -> db,
"table" -> table,
"zkCluster" -> ctx.getConf.openmldbZkCluster,
"zkPath" -> ctx.getConf.openmldbZkRootPath,
- "writerType" -> writeType
+ "writerType" -> writeType,
+ "putIfAbsent" -> putIfAbsent.toString
)
df.write.options(writeOptions).format("openmldb").mode(mode).save()
} else { // Import offline data
diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala
index 8bf6897d82f..6f3e5b78d40 100644
--- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala
+++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala
@@ -247,16 +247,17 @@ object HybridseUtil {
}
// extra options for some special case
- // only for PhysicalLoadDataNode
var extraOptions: mutable.Map[String, String] = mutable.Map()
+ // only for PhysicalLoadDataNode
extraOptions += ("deep_copy" -> parseOption(getOptionFromNode(node, "deep_copy"), "true", getBoolOrDefault))
-
- // only for select into, "" means N/A
- extraOptions += ("coalesce" -> parseOption(getOptionFromNode(node, "coalesce"), "0", getIntOrDefault))
- extraOptions += ("sql" -> parseOption(getOptionFromNode(node, "sql"), "", getStringOrDefault))
extraOptions += ("writer_type") -> parseOption(getOptionFromNode(node, "writer_type"), "single",
getStringOrDefault)
+ extraOptions += ("sql" -> parseOption(getOptionFromNode(node, "sql"), "", getStringOrDefault))
+ extraOptions += ("put_if_absent" -> parseOption(getOptionFromNode(node, "put_if_absent"), "false",
+ getBoolOrDefault))
+ // only for select into, "" means N/A
+ extraOptions += ("coalesce" -> parseOption(getOptionFromNode(node, "coalesce"), "0", getIntOrDefault))
extraOptions += ("create_if_not_exists" -> parseOption(getOptionFromNode(node, "create_if_not_exists"),
"true", getBoolOrDefault))
diff --git a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/jdbc/SQLConnection.java b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/jdbc/SQLConnection.java
index 5383eaf246d..8259682755d 100644
--- a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/jdbc/SQLConnection.java
+++ b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/jdbc/SQLConnection.java
@@ -82,7 +82,8 @@ public java.sql.Statement createStatement() throws SQLException {
@Override
public java.sql.PreparedStatement prepareStatement(String sql) throws SQLException {
String lower = sql.toLowerCase();
- if (lower.startsWith("insert into")) {
+ // matches plain "insert" as well as "insert or ..." variants
+ if (lower.startsWith("insert ")) {
return client.getInsertPreparedStmt(this.defaultDatabase, sql);
} else if (lower.startsWith("select")) {
return client.getPreparedStatement(this.defaultDatabase, sql);
diff --git a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/SdkOption.java b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/SdkOption.java
index eca5289bf32..66d0d83bef9 100644
--- a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/SdkOption.java
+++ b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/SdkOption.java
@@ -17,13 +17,14 @@
package com._4paradigm.openmldb.sdk;
import lombok.Data;
+import java.io.Serializable;
import com._4paradigm.openmldb.BasicRouterOptions;
import com._4paradigm.openmldb.SQLRouterOptions;
import com._4paradigm.openmldb.StandaloneOptions;
@Data
-public class SdkOption {
+public class SdkOption implements Serializable {
// TODO(hw): set isClusterMode automatically
private boolean isClusterMode = true;
// options for cluster mode
diff --git a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementImpl.java b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementImpl.java
index 6acefe8acff..ecc39b467c1 100644
--- a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementImpl.java
+++ b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementImpl.java
@@ -319,7 +319,7 @@ public boolean execute() throws SQLException {
// actually only one row
boolean ok = router.ExecuteInsert(cache.getDatabase(), cache.getName(),
cache.getTid(), cache.getPartitionNum(),
- dimensions.array(), dimensions.capacity(), value.array(), value.capacity(), status);
+ dimensions.array(), dimensions.capacity(), value.array(), value.capacity(), cache.isPutIfAbsent(), status);
// cleanup rows even if insert failed
// we can't execute() again without set new row, so we must clean up here
clearParameters();
@@ -381,7 +381,7 @@ public int[] executeBatch() throws SQLException {
boolean ok = router.ExecuteInsert(cache.getDatabase(), cache.getName(),
cache.getTid(), cache.getPartitionNum(),
pair.getKey().array(), pair.getKey().capacity(),
- pair.getValue().array(), pair.getValue().capacity(), status);
+ pair.getValue().array(), pair.getValue().capacity(), cache.isPutIfAbsent(), status);
if (!ok) {
// TODO(hw): may lost log, e.g. openmldb-batch online import in yarn mode?
logger.warn(status.ToString());
diff --git a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementMeta.java b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementMeta.java
index 448438e9d31..cf2bd05cb58 100644
--- a/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementMeta.java
+++ b/java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementMeta.java
@@ -31,6 +31,7 @@ public class InsertPreparedStatementMeta {
private Set<Integer> indexPos = new HashSet<>();
private Map<Integer, List<Integer>> indexMap = new HashMap<>();
private Map<Integer, String> defaultIndexValue = new HashMap<>();
+ private boolean putIfAbsent;
public InsertPreparedStatementMeta(String sql, NS.TableInfo tableInfo, SQLInsertRow insertRow) {
this.sql = sql;
@@ -51,6 +52,7 @@ public InsertPreparedStatementMeta(String sql, NS.TableInfo tableInfo, SQLInsert
VectorUint32 idxArray = insertRow.GetHoleIdx();
buildHoleIdx(idxArray);
idxArray.delete();
+ putIfAbsent = insertRow.IsPutIfAbsent();
}
private void buildIndex(NS.TableInfo tableInfo) {
@@ -215,4 +217,8 @@ Map<Integer, List<Integer>> getIndexMap() {
Map<Integer, String> getDefaultIndexValue() {
return defaultIndexValue;
}
+
+ public boolean isPutIfAbsent() {
+ return putIfAbsent;
+ }
}
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbConfig.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbConfig.java
new file mode 100644
index 00000000000..7c0981d0a6c
--- /dev/null
+++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbConfig.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com._4paradigm.openmldb.spark;
+
+import com._4paradigm.openmldb.sdk.SdkOption;
+
+import java.io.Serializable;
+
+import org.sparkproject.guava.base.Preconditions;
+
+// Must be serializable
+public class OpenmldbConfig implements Serializable {
+ public final static String DB = "db";
+ public final static String TABLE = "table";
+ public final static String ZK_CLUSTER = "zkCluster";
+ public final static String ZK_PATH = "zkPath";
+
+ /* read&write */
+ private String dbName;
+ private String tableName;
+ private SdkOption option = null;
+
+ /* write */
+    // single: insert as each row is read
+    // batch: insert on commit (after reading a whole partition)
+ private String writerType = "single";
+ private int insertMemoryUsageLimit = 0;
+ private boolean putIfAbsent = false;
+
+ public OpenmldbConfig() {
+ }
+
+ public void setDB(String dbName) {
+ Preconditions.checkArgument(dbName != null && !dbName.isEmpty(), "db name must not be empty");
+ this.dbName = dbName;
+ }
+
+ public String getDB() {
+ return this.dbName;
+ }
+
+ public void setTable(String tableName) {
+ Preconditions.checkArgument(tableName != null && !tableName.isEmpty(), "table name must not be empty");
+ this.tableName = tableName;
+ }
+
+ public String getTable() {
+ return this.tableName;
+ }
+
+ public void setSdkOption(SdkOption option) {
+ this.option = option;
+ }
+
+ public SdkOption getSdkOption() {
+ return this.option;
+ }
+
+ public void setWriterType(String type) {
+ Preconditions.checkArgument(type.equals("single") || type.equals("batch"),
+ "writerType must be 'single' or 'batch'");
+ this.writerType = type;
+ }
+
+ public void setInsertMemoryUsageLimit(int limit) {
+ Preconditions.checkArgument(limit >= 0, "insert_memory_usage_limit must be >= 0");
+ this.insertMemoryUsageLimit = limit;
+ }
+
+ public void setPutIfAbsent(Boolean value) {
+ this.putIfAbsent = value;
+ }
+
+ public boolean isBatchWriter() {
+ return this.writerType.equals("batch");
+ }
+
+ public boolean putIfAbsent() {
+ return this.putIfAbsent;
+ }
+
+ public int getInsertMemoryUsageLimit() {
+ return this.insertMemoryUsageLimit;
+ }
+
+}
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbSource.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbSource.java
index 7e626f623ea..9dfe78f0197 100644
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbSource.java
+++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbSource.java
@@ -18,7 +18,6 @@
package com._4paradigm.openmldb.spark;
import com._4paradigm.openmldb.sdk.SdkOption;
-import com.google.common.base.Preconditions;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableProvider;
import org.apache.spark.sql.connector.expressions.Transform;
@@ -29,32 +28,20 @@
import java.util.Map;
public class OpenmldbSource implements TableProvider, DataSourceRegister {
- private final String DB = "db";
- private final String TABLE = "table";
- private final String ZK_CLUSTER = "zkCluster";
- private final String ZK_PATH = "zkPath";
- private String dbName;
- private String tableName;
- private SdkOption option = null;
- // single: insert when read one row
- // batch: insert when commit(after read a whole partition)
- private String writerType = "single";
- private int insertMemoryUsageLimit = 0;
+ private OpenmldbConfig config = new OpenmldbConfig();
@Override
public StructType inferSchema(CaseInsensitiveStringMap options) {
- Preconditions.checkNotNull(dbName = options.get(DB));
- Preconditions.checkNotNull(tableName = options.get(TABLE));
+ config.setDB(options.get(OpenmldbConfig.DB));
+ config.setTable(options.get(OpenmldbConfig.TABLE));
- String zkCluster = options.get(ZK_CLUSTER);
- String zkPath = options.get(ZK_PATH);
- Preconditions.checkNotNull(zkCluster);
- Preconditions.checkNotNull(zkPath);
- option = new SdkOption();
- option.setZkCluster(zkCluster);
- option.setZkPath(zkPath);
+ SdkOption option = new SdkOption();
+ option.setZkCluster(options.get(OpenmldbConfig.ZK_CLUSTER));
+ option.setZkPath(options.get(OpenmldbConfig.ZK_PATH));
option.setLight(true);
+ config.setSdkOption(option);
+
String timeout = options.get("sessionTimeout");
if (timeout != null) {
option.setSessionTimeout(Integer.parseInt(timeout));
@@ -69,18 +56,21 @@ public StructType inferSchema(CaseInsensitiveStringMap options) {
}
if (options.containsKey("writerType")) {
- writerType = options.get("writerType");
+ config.setWriterType(options.get("writerType"));
+ }
+ if (options.containsKey("putIfAbsent")) {
+ config.setPutIfAbsent(Boolean.valueOf(options.get("putIfAbsent")));
}
if (options.containsKey("insert_memory_usage_limit")) {
- insertMemoryUsageLimit = Integer.parseInt(options.get("insert_memory_usage_limit"));
+ config.setInsertMemoryUsageLimit(Integer.parseInt(options.get("insert_memory_usage_limit")));
}
return null;
}
@Override
public Table getTable(StructType schema, Transform[] partitioning, Map<String, String> properties) {
- return new OpenmldbTable(dbName, tableName, option, writerType, insertMemoryUsageLimit);
+ return new OpenmldbTable(config);
}
@Override
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbTable.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbTable.java
index 481a9cc1f4c..e5cbcfe40ca 100644
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbTable.java
+++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/OpenmldbTable.java
@@ -22,10 +22,8 @@
import com._4paradigm.openmldb.sdk.SqlException;
import com._4paradigm.openmldb.sdk.SqlExecutor;
import com._4paradigm.openmldb.sdk.impl.SqlClusterExecutor;
-import com._4paradigm.openmldb.spark.read.OpenmldbReadConfig;
import com._4paradigm.openmldb.spark.read.OpenmldbScanBuilder;
import com._4paradigm.openmldb.spark.write.OpenmldbWriteBuilder;
-import com._4paradigm.openmldb.spark.write.OpenmldbWriteConfig;
import org.apache.spark.sql.connector.catalog.SupportsRead;
import org.apache.spark.sql.connector.catalog.SupportsWrite;
import org.apache.spark.sql.connector.catalog.TableCapability;
@@ -45,40 +43,32 @@
import java.util.Set;
public class OpenmldbTable implements SupportsWrite, SupportsRead {
- private final String dbName;
- private final String tableName;
- private final SdkOption option;
- private final String writerType;
- private final int insertMemoryUsageLimit;
- private SqlExecutor executor = null;
+ private OpenmldbConfig config;
+ private SqlExecutor executor;
private Set<TableCapability> capabilities;
- public OpenmldbTable(String dbName, String tableName, SdkOption option, String writerType, int insertMemoryUsageLimit) {
- this.dbName = dbName;
- this.tableName = tableName;
- this.option = option;
- this.writerType = writerType;
- this.insertMemoryUsageLimit = insertMemoryUsageLimit;
+ public OpenmldbTable(OpenmldbConfig config) {
+ this.config = config;
try {
- this.executor = new SqlClusterExecutor(option);
+ this.executor = new SqlClusterExecutor(config.getSdkOption());
// no need to check table exists, schema() will check it later
} catch (SqlException e) {
e.printStackTrace();
+ throw new RuntimeException("connect to openmldb failed", e);
}
// TODO: cache schema & delete executor?
}
@Override
public WriteBuilder newWriteBuilder(LogicalWriteInfo info) {
- OpenmldbWriteConfig config = new OpenmldbWriteConfig(dbName, tableName, option, writerType, insertMemoryUsageLimit);
return new OpenmldbWriteBuilder(config, info);
}
@Override
public String name() {
// TODO(hw): db?
- return tableName;
+ return config.getTable();
}
public static DataType sdkTypeToSparkType(int sqlType) {
@@ -109,7 +99,7 @@ public static DataType sdkTypeToSparkType(int sqlType) {
@Override
public StructType schema() {
try {
- Schema schema = executor.getTableSchema(dbName, tableName);
+ Schema schema = executor.getTableSchema(config.getDB(), config.getTable());
List<Column> schemaList = schema.getColumnList();
StructField[] fields = new StructField[schemaList.size()];
for (int i = 0; i < schemaList.size(); i++) {
@@ -136,7 +126,6 @@ public Set<TableCapability> capabilities() {
@Override
public ScanBuilder newScanBuilder(CaseInsensitiveStringMap caseInsensitiveStringMap) {
- OpenmldbReadConfig config = new OpenmldbReadConfig(dbName, tableName, option);
return new OpenmldbScanBuilder(config);
}
}
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReaderFactory.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReaderFactory.java
index d5e435fc247..929d30b728e 100644
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReaderFactory.java
+++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReaderFactory.java
@@ -17,15 +17,16 @@
package com._4paradigm.openmldb.spark.read;
+import com._4paradigm.openmldb.spark.OpenmldbConfig;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.PartitionReader;
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
public class OpenmldbPartitionReaderFactory implements PartitionReaderFactory {
- private final OpenmldbReadConfig config;
+ private final OpenmldbConfig config;
- public OpenmldbPartitionReaderFactory(OpenmldbReadConfig config) {
+ public OpenmldbPartitionReaderFactory(OpenmldbConfig config) {
this.config = config;
}
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbReadConfig.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbReadConfig.java
deleted file mode 100644
index 91489888ba9..00000000000
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbReadConfig.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com._4paradigm.openmldb.spark.read;
-
-import com._4paradigm.openmldb.sdk.SdkOption;
-import java.io.Serializable;
-
-// Must serializable
-public class OpenmldbReadConfig implements Serializable {
- public final String dbName, tableName, zkCluster, zkPath;
-
- public OpenmldbReadConfig(String dbName, String tableName, SdkOption option) {
- this.dbName = dbName;
- this.tableName = tableName;
- this.zkCluster = option.getZkCluster();
- this.zkPath = option.getZkPath();
- // TODO(hw): other configs in SdkOption
- }
-}
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScan.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScan.java
index fb7adb46b8e..4eeac9a6013 100644
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScan.java
+++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScan.java
@@ -17,6 +17,7 @@
package com._4paradigm.openmldb.spark.read;
+import com._4paradigm.openmldb.spark.OpenmldbConfig;
import org.apache.spark.sql.connector.read.Batch;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
@@ -24,9 +25,9 @@
import org.apache.spark.sql.types.StructType;
public class OpenmldbScan implements Scan, Batch {
- private final OpenmldbReadConfig config;
+ private final OpenmldbConfig config;
- public OpenmldbScan(OpenmldbReadConfig config) {
+ public OpenmldbScan(OpenmldbConfig config) {
this.config = config;
}
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScanBuilder.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScanBuilder.java
index 2b500a6592e..de59a811f46 100644
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScanBuilder.java
+++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/read/OpenmldbScanBuilder.java
@@ -17,13 +17,14 @@
package com._4paradigm.openmldb.spark.read;
+import com._4paradigm.openmldb.spark.OpenmldbConfig;
import org.apache.spark.sql.connector.read.Scan;
import org.apache.spark.sql.connector.read.ScanBuilder;
public class OpenmldbScanBuilder implements ScanBuilder {
- private final OpenmldbReadConfig config;
+ private final OpenmldbConfig config;
- public OpenmldbScanBuilder(OpenmldbReadConfig config) {
+ public OpenmldbScanBuilder(OpenmldbConfig config) {
this.config = config;
}
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbBatchWrite.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbBatchWrite.java
index ca90a07d63a..d19fd9f6aeb 100644
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbBatchWrite.java
+++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbBatchWrite.java
@@ -17,6 +17,7 @@
package com._4paradigm.openmldb.spark.write;
+import com._4paradigm.openmldb.spark.OpenmldbConfig;
import org.apache.spark.sql.connector.write.BatchWrite;
import org.apache.spark.sql.connector.write.DataWriterFactory;
import org.apache.spark.sql.connector.write.LogicalWriteInfo;
@@ -24,10 +25,10 @@
import org.apache.spark.sql.connector.write.WriterCommitMessage;
public class OpenmldbBatchWrite implements BatchWrite {
- private final OpenmldbWriteConfig config;
+ private final OpenmldbConfig config;
private final LogicalWriteInfo info;
- public OpenmldbBatchWrite(OpenmldbWriteConfig config, LogicalWriteInfo info) {
+ public OpenmldbBatchWrite(OpenmldbConfig config, LogicalWriteInfo info) {
this.config = config;
this.info = info;
}
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataSingleWriter.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataSingleWriter.java
index 843fb9a8da7..cc5f0150cc3 100644
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataSingleWriter.java
+++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataSingleWriter.java
@@ -17,8 +17,9 @@
package com._4paradigm.openmldb.spark.write;
+import com._4paradigm.openmldb.spark.OpenmldbConfig;
+
import com._4paradigm.openmldb.sdk.Schema;
-import com._4paradigm.openmldb.sdk.SdkOption;
import com._4paradigm.openmldb.sdk.SqlException;
import com._4paradigm.openmldb.sdk.impl.SqlClusterExecutor;
import com.google.common.base.Preconditions;
@@ -27,32 +28,26 @@
import org.apache.spark.sql.connector.write.WriterCommitMessage;
import java.io.IOException;
-import java.sql.Date;
import java.sql.PreparedStatement;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
-import java.sql.Timestamp;
-import java.sql.Types;
public class OpenmldbDataSingleWriter implements DataWriter<InternalRow> {
private final int partitionId;
private final long taskId;
private PreparedStatement preparedStatement = null;
- public OpenmldbDataSingleWriter(OpenmldbWriteConfig config, int partitionId, long taskId) {
+ public OpenmldbDataSingleWriter(OpenmldbConfig config, int partitionId, long taskId) {
try {
- SdkOption option = new SdkOption();
- option.setZkCluster(config.zkCluster);
- option.setZkPath(config.zkPath);
- option.setLight(true);
- SqlClusterExecutor executor = new SqlClusterExecutor(option);
- String dbName = config.dbName;
- String tableName = config.tableName;
- executor.executeSQL(dbName, "SET @@insert_memory_usage_limit=" + config.insertMemoryUsageLimit);
+ SqlClusterExecutor executor = new SqlClusterExecutor(config.getSdkOption());
+ String dbName = config.getDB();
+ String tableName = config.getTable();
+ executor.executeSQL(dbName, "SET @@insert_memory_usage_limit=" + config.getInsertMemoryUsageLimit());
Schema schema = executor.getTableSchema(dbName, tableName);
// create insert placeholder
- StringBuilder insert = new StringBuilder("insert into " + tableName + " values(?");
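+ // put_if_absent maps to "insert or ignore", so rows that already exist are skipped on retry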
+ String insertPart = config.putIfAbsent() ? "insert or ignore into " : "insert into ";
+ StringBuilder insert = new StringBuilder(insertPart + tableName + " values(?");
for (int i = 1; i < schema.getColumnList().size(); i++) {
insert.append(",?");
}
@@ -60,6 +55,7 @@ public OpenmldbDataSingleWriter(OpenmldbWriteConfig config, int partitionId, lon
preparedStatement = executor.getInsertPreparedStmt(dbName, insert.toString());
} catch (SQLException | SqlException e) {
e.printStackTrace();
+ throw new RuntimeException("create openmldb writer failed", e);
}
this.partitionId = partitionId;
@@ -73,7 +69,12 @@ public void write(InternalRow record) throws IOException {
ResultSetMetaData metaData = preparedStatement.getMetaData();
Preconditions.checkState(record.numFields() == metaData.getColumnCount());
OpenmldbDataWriter.addRow(record, preparedStatement);
- preparedStatement.execute();
+ // check the return value of the put
+ // we could cache failed rows and throw when commit/close is called,
+ // but that may still interrupt other writers (pending or slow writers)
+ if (!preparedStatement.execute()) {
+ throw new IOException("execute failed");
+ }
} catch (Exception e) {
throw new IOException("write row to openmldb failed on " + record, e);
}
@@ -81,24 +82,13 @@ public void write(InternalRow record) throws IOException {
@Override
public WriterCommitMessage commit() throws IOException {
- try {
- preparedStatement.close();
- } catch (SQLException e) {
- e.printStackTrace();
- throw new IOException("commit error", e);
- }
- // TODO(hw): need to return new WriterCommitMessageImpl(partitionId, taskId); ?
+ // no transaction, no commit
return null;
}
@Override
public void abort() throws IOException {
- try {
- preparedStatement.close();
- } catch (SQLException e) {
- e.printStackTrace();
- throw new IOException("abort error", e);
- }
+ // no transaction, no abort
}
@Override
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriter.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriter.java
index e9ba0e30c5a..65bc2e5a457 100644
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriter.java
+++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriter.java
@@ -17,6 +17,8 @@
package com._4paradigm.openmldb.spark.write;
+import com._4paradigm.openmldb.spark.OpenmldbConfig;
+
import com._4paradigm.openmldb.sdk.Schema;
import com._4paradigm.openmldb.sdk.SdkOption;
import com._4paradigm.openmldb.sdk.SqlException;
@@ -39,20 +41,17 @@ public class OpenmldbDataWriter implements DataWriter<InternalRow> {
private final long taskId;
private PreparedStatement preparedStatement = null;
- public OpenmldbDataWriter(OpenmldbWriteConfig config, int partitionId, long taskId) {
+ public OpenmldbDataWriter(OpenmldbConfig config, int partitionId, long taskId) {
try {
- SdkOption option = new SdkOption();
- option.setZkCluster(config.zkCluster);
- option.setZkPath(config.zkPath);
- option.setLight(true);
- SqlClusterExecutor executor = new SqlClusterExecutor(option);
- String dbName = config.dbName;
- String tableName = config.tableName;
- executor.executeSQL(dbName, "SET @@insert_memory_usage_limit=" + config.insertMemoryUsageLimit);
+ SqlClusterExecutor executor = new SqlClusterExecutor(config.getSdkOption());
+ String dbName = config.getDB();
+ String tableName = config.getTable();
+ executor.executeSQL(dbName, "SET @@insert_memory_usage_limit=" + config.getInsertMemoryUsageLimit());
Schema schema = executor.getTableSchema(dbName, tableName);
// create insert placeholder
- StringBuilder insert = new StringBuilder("insert into " + tableName + " values(?");
+ String insertPart = config.putIfAbsent() ? "insert or ignore into " : "insert into ";
+ StringBuilder insert = new StringBuilder(insertPart + tableName + " values(?");
for (int i = 1; i < schema.getColumnList().size(); i++) {
insert.append(",?");
}
@@ -60,6 +59,7 @@ public OpenmldbDataWriter(OpenmldbWriteConfig config, int partitionId, long task
preparedStatement = executor.getInsertPreparedStmt(dbName, insert.toString());
} catch (SQLException | SqlException e) {
e.printStackTrace();
+ throw new RuntimeException("create openmldb data writer failed", e);
}
this.partitionId = partitionId;
@@ -147,12 +147,7 @@ public WriterCommitMessage commit() throws IOException {
@Override
public void abort() throws IOException {
- try {
- preparedStatement.close();
- } catch (SQLException e) {
- e.printStackTrace();
- throw new IOException("abort error", e);
- }
+ // no transaction, no abort
}
@Override
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriterFactory.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriterFactory.java
index 96e78979b2f..12cefb3928b 100644
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriterFactory.java
+++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbDataWriterFactory.java
@@ -17,20 +17,21 @@
package com._4paradigm.openmldb.spark.write;
+import com._4paradigm.openmldb.spark.OpenmldbConfig;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.DataWriterFactory;
public class OpenmldbDataWriterFactory implements DataWriterFactory {
- private final OpenmldbWriteConfig config;
+ private final OpenmldbConfig config;
- public OpenmldbDataWriterFactory(OpenmldbWriteConfig config) {
+ public OpenmldbDataWriterFactory(OpenmldbConfig config) {
this.config = config;
}
@Override
public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
- if (!config.writerType.equals("batch")) {
+ if (!config.isBatchWriter()) {
return new OpenmldbDataSingleWriter(config, partitionId, taskId);
}
return new OpenmldbDataWriter(config, partitionId, taskId);
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteBuilder.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteBuilder.java
index a3c905b15c1..ccd588df0c4 100644
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteBuilder.java
+++ b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteBuilder.java
@@ -17,15 +17,16 @@
package com._4paradigm.openmldb.spark.write;
+import com._4paradigm.openmldb.spark.OpenmldbConfig;
import org.apache.spark.sql.connector.write.BatchWrite;
import org.apache.spark.sql.connector.write.LogicalWriteInfo;
import org.apache.spark.sql.connector.write.WriteBuilder;
public class OpenmldbWriteBuilder implements WriteBuilder {
- private final OpenmldbWriteConfig config;
+ private final OpenmldbConfig config;
private final LogicalWriteInfo info;
- public OpenmldbWriteBuilder(OpenmldbWriteConfig config, LogicalWriteInfo info) {
+ public OpenmldbWriteBuilder(OpenmldbConfig config, LogicalWriteInfo info) {
this.config = config;
this.info = info;
}
diff --git a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteConfig.java b/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteConfig.java
deleted file mode 100644
index 80007b14ae5..00000000000
--- a/java/openmldb-spark-connector/src/main/java/com/_4paradigm/openmldb/spark/write/OpenmldbWriteConfig.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com._4paradigm.openmldb.spark.write;
-
-import com._4paradigm.openmldb.sdk.SdkOption;
-
-import java.io.Serializable;
-
-// Must serializable
-public class OpenmldbWriteConfig implements Serializable {
- public final String dbName, tableName, zkCluster, zkPath, writerType;
- public final int insertMemoryUsageLimit;
-
- public OpenmldbWriteConfig(String dbName, String tableName, SdkOption option, String writerType, int insertMemoryUsageLimit) {
- this.dbName = dbName;
- this.tableName = tableName;
- this.zkCluster = option.getZkCluster();
- this.zkPath = option.getZkPath();
- this.writerType = writerType;
- this.insertMemoryUsageLimit = insertMemoryUsageLimit;
- // TODO(hw): other configs in SdkOption
- }
-}
diff --git a/java/openmldb-spark-connector/src/main/scala/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReader.scala b/java/openmldb-spark-connector/src/main/scala/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReader.scala
index d8eeb89e7ab..86921d0f4d5 100644
--- a/java/openmldb-spark-connector/src/main/scala/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReader.scala
+++ b/java/openmldb-spark-connector/src/main/scala/com/_4paradigm/openmldb/spark/read/OpenmldbPartitionReader.scala
@@ -1,5 +1,6 @@
package com._4paradigm.openmldb.spark.read
+import com._4paradigm.openmldb.spark.OpenmldbConfig
import com._4paradigm.openmldb.sdk.{Schema, SdkOption}
import com._4paradigm.openmldb.sdk.impl.SqlClusterExecutor
import org.apache.spark.sql.catalyst.InternalRow
@@ -8,15 +9,10 @@ import org.apache.spark.unsafe.types.UTF8String
import java.sql.Types
-class OpenmldbPartitionReader(config: OpenmldbReadConfig) extends PartitionReader[InternalRow] {
-
- val option = new SdkOption
- option.setZkCluster(config.zkCluster)
- option.setZkPath(config.zkPath)
- option.setLight(true)
- val executor = new SqlClusterExecutor(option)
- val dbName: String = config.dbName
- val tableName: String = config.tableName
+class OpenmldbPartitionReader(config: OpenmldbConfig) extends PartitionReader[InternalRow] {
+ val executor = new SqlClusterExecutor(config.getSdkOption)
+ val dbName: String = config.getDB
+ val tableName: String = config.getTable
val schema: Schema = executor.getTableSchema(dbName, tableName)
executor.executeSQL(dbName, "SET @@execute_mode='online'")
diff --git a/src/client/tablet_client.cc b/src/client/tablet_client.cc
index 54b2a8c9cec..f445cc1791c 100644
--- a/src/client/tablet_client.cc
+++ b/src/client/tablet_client.cc
@@ -203,19 +203,21 @@ bool TabletClient::UpdateTableMetaForAddField(uint32_t tid, const std::vector
base::Status TabletClient::Put(uint32_t tid, uint32_t pid, uint64_t time, const std::string& value,
const std::vector<std::pair<std::string, uint32_t>>& dimensions,
- int memory_usage_limit) {
+ int memory_usage_limit, bool put_if_absent) {
+
::google::protobuf::RepeatedPtrField<::openmldb::api::Dimension> pb_dimensions;
for (size_t i = 0; i < dimensions.size(); i++) {
::openmldb::api::Dimension* d = pb_dimensions.Add();
d->set_key(dimensions[i].first);
d->set_idx(dimensions[i].second);
}
- return Put(tid, pid, time, base::Slice(value), &pb_dimensions, memory_usage_limit);
+
+ return Put(tid, pid, time, base::Slice(value), &pb_dimensions, memory_usage_limit, put_if_absent);
}
base::Status TabletClient::Put(uint32_t tid, uint32_t pid, uint64_t time, const base::Slice& value,
::google::protobuf::RepeatedPtrField<::openmldb::api::Dimension>* dimensions,
- int memory_usage_limit) {
+ int memory_usage_limit, bool put_if_absent) {
::openmldb::api::PutRequest request;
if (memory_usage_limit < 0 || memory_usage_limit > 100) {
return {base::ReturnCode::kError, absl::StrCat("invalid memory_usage_limit ", memory_usage_limit)};
@@ -227,6 +229,7 @@ base::Status TabletClient::Put(uint32_t tid, uint32_t pid, uint64_t time, const
request.set_tid(tid);
request.set_pid(pid);
request.mutable_dimensions()->Swap(dimensions);
+ request.set_put_if_absent(put_if_absent);
::openmldb::api::PutResponse response;
auto st = client_.SendRequestSt(&::openmldb::api::TabletServer_Stub::Put,
&request, &response, FLAGS_request_timeout_ms, 1);
diff --git a/src/client/tablet_client.h b/src/client/tablet_client.h
index b4866b77618..19579f90c5c 100644
--- a/src/client/tablet_client.h
+++ b/src/client/tablet_client.h
@@ -76,20 +76,18 @@ class TabletClient : public Client {
base::Status Put(uint32_t tid, uint32_t pid, uint64_t time, const std::string& value,
const std::vector<std::pair<std::string, uint32_t>>& dimensions,
- int memory_usage_limit = 0);
+ int memory_usage_limit = 0, bool put_if_absent = false);
base::Status Put(uint32_t tid, uint32_t pid, uint64_t time, const base::Slice& value,
::google::protobuf::RepeatedPtrField<::openmldb::api::Dimension>* dimensions,
- int memory_usage_limit = 0);
+ int memory_usage_limit = 0, bool put_if_absent = false);
bool Get(uint32_t tid, uint32_t pid, const std::string& pk, uint64_t time, std::string& value, // NOLINT
uint64_t& ts, // NOLINT
std::string& msg); // NOLINT
bool Get(uint32_t tid, uint32_t pid, const std::string& pk, uint64_t time, const std::string& idx_name,
- std::string& value, // NOLINT
- uint64_t& ts, // NOLINT
- std::string& msg); // NOLINT
+ std::string& value, uint64_t& ts, std::string& msg); // NOLINT
bool Delete(uint32_t tid, uint32_t pid, const std::string& pk, const std::string& idx_name,
std::string& msg); // NOLINT
diff --git a/src/cmd/openmldb.cc b/src/cmd/openmldb.cc
index 8d0d9b692f5..53da31ad634 100644
--- a/src/cmd/openmldb.cc
+++ b/src/cmd/openmldb.cc
@@ -438,7 +438,7 @@ std::shared_ptr<::openmldb::client::TabletClient> GetTabletClient(const ::openml
void HandleNSClientSetTTL(const std::vector& parts, ::openmldb::client::NsClient* client) {
if (parts.size() < 4) {
- std::cout << "bad setttl format, eg settl t1 absolute 10" << std::endl;
+ std::cout << "bad setttl format, eg setttl t1 absolute 10 [index0]" << std::endl;
return;
}
std::string index_name;
@@ -1307,14 +1307,14 @@ void HandleNSGet(const std::vector& parts, ::openmldb::client::NsCl
if (parts.size() < 4) {
std::cout << "get format error. eg: get table_name key ts | get "
"table_name key idx_name ts | get table_name=xxx key=xxx "
- "index_name=xxx ts=xxx ts_name=xxx "
+ "index_name=xxx ts=xxx"
<< std::endl;
return;
}
std::map parameter_map;
if (!GetParameterMap("table_name", parts, "=", parameter_map)) {
std::cout << "get format error. eg: get table_name=xxx key=xxx "
- "index_name=xxx ts=xxx ts_name=xxx "
+ "index_name=xxx ts=xxx"
<< std::endl;
return;
}
@@ -1382,7 +1382,7 @@ void HandleNSGet(const std::vector& parts, ::openmldb::client::NsCl
return;
}
::openmldb::codec::SDKCodec codec(tables[0]);
- bool no_schema = tables[0].column_desc_size() == 0 && tables[0].column_desc_size() == 0;
+ bool no_schema = tables[0].column_desc_size() == 0;
if (no_schema) {
std::string value;
uint64_t ts = 0;
@@ -2459,7 +2459,7 @@ void HandleNSClientHelp(const std::vector& parts, ::openmldb::clien
printf("ex:man create\n");
} else if (parts[1] == "setttl") {
printf("desc: set table ttl \n");
- printf("usage: setttl table_name ttl_type ttl [ts_name]\n");
+ printf("usage: setttl table_name ttl_type ttl [index_name], absolute ttl unit is minutes\n");
printf("ex: setttl t1 absolute 10\n");
printf("ex: setttl t2 latest 5\n");
printf("ex: setttl t3 latest 5 ts1\n");
diff --git a/src/proto/tablet.proto b/src/proto/tablet.proto
index 2c7a038960b..a18714b2ae1 100755
--- a/src/proto/tablet.proto
+++ b/src/proto/tablet.proto
@@ -195,6 +195,7 @@ message PutRequest {
repeated TSDimension ts_dimensions = 7 [deprecated = true];
optional uint32 format_version = 8 [default = 0, deprecated = true];
optional uint32 memory_limit = 9;
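+  // when true, the tablet skips rows that already exist instead of inserting duplicates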
+ optional bool put_if_absent = 10 [default = false];
}
message PutResponse {
diff --git a/src/sdk/node_adapter_test.cc b/src/sdk/node_adapter_test.cc
index e09758b07cd..70c35ff7d9c 100644
--- a/src/sdk/node_adapter_test.cc
+++ b/src/sdk/node_adapter_test.cc
@@ -64,7 +64,7 @@ void CheckTablePartition(const ::openmldb::nameserver::TableInfo& table_info,
if (table_partition.partition_meta(pos).is_leader()) {
ASSERT_EQ(table_partition.partition_meta(pos).endpoint(), leader);
} else {
- ASSERT_EQ(follower.count(table_partition.partition_meta(pos).endpoint()), 1);
+ ASSERT_EQ(follower.count(table_partition.partition_meta(pos).endpoint()), (std::size_t)1);
}
}
}
diff --git a/src/sdk/sql_cache.h b/src/sdk/sql_cache.h
index a326437c10f..1fe0b346fa2 100644
--- a/src/sdk/sql_cache.h
+++ b/src/sdk/sql_cache.h
@@ -54,26 +54,28 @@ class InsertSQLCache : public SQLCache {
InsertSQLCache(const std::shared_ptr<::openmldb::nameserver::TableInfo>& table_info,
const std::shared_ptr<::hybridse::sdk::Schema>& column_schema,
DefaultValueMap default_map,
- uint32_t str_length, std::vector<uint32_t> hole_idx_arr)
+ uint32_t str_length, std::vector<uint32_t> hole_idx_arr, bool put_if_absent)
: SQLCache(table_info->db(), table_info->tid(), table_info->name()),
table_info_(table_info),
column_schema_(column_schema),
default_map_(std::move(default_map)),
str_length_(str_length),
- hole_idx_arr_(std::move(hole_idx_arr)) {}
+ hole_idx_arr_(std::move(hole_idx_arr)),
+ put_if_absent_(put_if_absent) {}
std::shared_ptr<::openmldb::nameserver::TableInfo> GetTableInfo() { return table_info_; }
std::shared_ptr<::hybridse::sdk::Schema> GetSchema() const { return column_schema_; }
uint32_t GetStrLength() const { return str_length_; }
const DefaultValueMap& GetDefaultValue() const { return default_map_; }
const std::vector<uint32_t>& GetHoleIdxArr() const { return hole_idx_arr_; }
-
+ bool IsPutIfAbsent() const { return put_if_absent_; }
private:
std::shared_ptr<::openmldb::nameserver::TableInfo> table_info_;
std::shared_ptr<::hybridse::sdk::Schema> column_schema_;
const DefaultValueMap default_map_;
const uint32_t str_length_;
const std::vector<uint32_t> hole_idx_arr_;
+ const bool put_if_absent_;
};
class RouterSQLCache : public SQLCache {
diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc
index 7c5a98814b9..bdad16cfc8c 100644
--- a/src/sdk/sql_cluster_router.cc
+++ b/src/sdk/sql_cluster_router.cc
@@ -455,39 +455,40 @@ std::shared_ptr<SQLInsertRow> SQLClusterRouter::GetInsertRow(const std::string&
*status = {};
return std::make_shared<SQLInsertRow>(insert_cache->GetTableInfo(), insert_cache->GetSchema(),
insert_cache->GetDefaultValue(), insert_cache->GetStrLength(),
- insert_cache->GetHoleIdxArr());
+ insert_cache->GetHoleIdxArr(), insert_cache->IsPutIfAbsent());
}
}
std::shared_ptr<::openmldb::nameserver::TableInfo> table_info;
DefaultValueMap default_map;
uint32_t str_length = 0;
std::vector<size_t> stmt_column_idx_arr;
- if (!GetInsertInfo(db, sql, status, &table_info, &default_map, &str_length, &stmt_column_idx_arr)) {
+ bool put_if_absent = false;
+ if (!GetInsertInfo(db, sql, status, &table_info, &default_map, &str_length, &stmt_column_idx_arr, &put_if_absent)) {
SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "get insert information failed");
return {};
}
auto schema = openmldb::schema::SchemaAdapter::ConvertSchema(table_info->column_desc());
- auto insert_cache =
- std::make_shared<InsertSQLCache>(table_info, schema, default_map, str_length,
- SQLInsertRow::GetHoleIdxArr(default_map, stmt_column_idx_arr, schema));
+ auto insert_cache = std::make_shared(
+ table_info, schema, default_map, str_length,
+ SQLInsertRow::GetHoleIdxArr(default_map, stmt_column_idx_arr, schema), put_if_absent);
SetCache(db, sql, hybridse::vm::kBatchMode, insert_cache);
*status = {};
return std::make_shared<SQLInsertRow>(insert_cache->GetTableInfo(), insert_cache->GetSchema(),
insert_cache->GetDefaultValue(), insert_cache->GetStrLength(),
- insert_cache->GetHoleIdxArr());
+ insert_cache->GetHoleIdxArr(), insert_cache->IsPutIfAbsent());
}
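
Note that `put_if_absent` is captured in the cache entry and replayed on every cache hit. Since the cache is keyed by the SQL text (per database and engine mode, see the `SetCache(db, sql, hybridse::vm::kBatchMode, ...)` call above), a plain `INSERT` and an `INSERT OR IGNORE` over the same table hash to different entries and cannot pick up each other's flag.
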
bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::string& sql,
::hybridse::sdk::Status* status,
std::shared_ptr<::openmldb::nameserver::TableInfo>* table_info,
std::vector<DefaultValueMap>* default_maps,
- std::vector<uint32_t>* str_lengths) {
+ std::vector<uint32_t>* str_lengths, bool* put_if_absent) {
RET_FALSE_IF_NULL_AND_WARN(status, "output status is nullptr");
// TODO(hw): return status?
RET_FALSE_IF_NULL_AND_WARN(table_info, "output table_info is nullptr");
RET_FALSE_IF_NULL_AND_WARN(default_maps, "output default_maps is nullptr");
RET_FALSE_IF_NULL_AND_WARN(str_lengths, "output str_lengths is nullptr");
-
+ RET_FALSE_IF_NULL_AND_WARN(put_if_absent, "output put_if_absent is nullptr");
::hybridse::node::NodeManager nm;
::hybridse::plan::PlanNodeList plans;
bool ok = GetSQLPlan(sql, &nm, &plans);
@@ -506,6 +507,7 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s
SET_STATUS_AND_WARN(status, StatusCode::kPlanError, "insert stmt is null");
return false;
}
+ *put_if_absent = insert_stmt->insert_mode_ == ::hybridse::node::InsertStmt::IGNORE;
std::string db_name;
if (!insert_stmt->db_name_.empty()) {
db_name = insert_stmt->db_name_;
@@ -576,7 +578,7 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s
bool SQLClusterRouter::GetInsertInfo(const std::string& db, const std::string& sql, ::hybridse::sdk::Status* status,
std::shared_ptr<::openmldb::nameserver::TableInfo>* table_info,
DefaultValueMap* default_map, uint32_t* str_length,
- std::vector<uint32_t>* stmt_column_idx_in_table) {
+ std::vector<uint32_t>* stmt_column_idx_in_table, bool* put_if_absent) {
RET_FALSE_IF_NULL_AND_WARN(status, "output status is nullptr");
RET_FALSE_IF_NULL_AND_WARN(table_info, "output table_info is nullptr");
RET_FALSE_IF_NULL_AND_WARN(default_map, "output default_map is nullptr");
@@ -636,6 +638,7 @@ bool SQLClusterRouter::GetInsertInfo(const std::string& db, const std::string& s
SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "get default value map of " + sql + " failed");
return false;
}
+ *put_if_absent = insert_stmt->insert_mode_ == ::hybridse::node::InsertStmt::IGNORE;
return true;
}
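
Both `GetInsertInfo` and `GetMultiRowInsertInfo` derive the flag the same way from the parsed statement. A self-contained sketch of the mapping (the enum here is a hypothetical stand-in for `::hybridse::node::InsertStmt`'s insert mode):

```cpp
#include <cassert>

enum class InsertMode { DEFAULT_MODE, IGNORE };

// "INSERT INTO t ..."           -> DEFAULT_MODE -> put_if_absent = false
// "INSERT OR IGNORE INTO t ..." -> IGNORE       -> put_if_absent = true
bool DerivePutIfAbsent(InsertMode mode) { return mode == InsertMode::IGNORE; }

int main() {
    assert(!DerivePutIfAbsent(InsertMode::DEFAULT_MODE));
    assert(DerivePutIfAbsent(InsertMode::IGNORE));
    return 0;
}
```
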
@@ -771,23 +774,24 @@ std::shared_ptr<SQLInsertRows> SQLClusterRouter::GetInsertRows(const std::string
status->SetOK();
return std::make_shared<SQLInsertRows>(insert_cache->GetTableInfo(), insert_cache->GetSchema(),
insert_cache->GetDefaultValue(), insert_cache->GetStrLength(),
- insert_cache->GetHoleIdxArr());
+ insert_cache->GetHoleIdxArr(), insert_cache->IsPutIfAbsent());
}
}
std::shared_ptr<::openmldb::nameserver::TableInfo> table_info;
DefaultValueMap default_map;
uint32_t str_length = 0;
std::vector<uint32_t> stmt_column_idx_arr;
- if (!GetInsertInfo(db, sql, status, &table_info, &default_map, &str_length, &stmt_column_idx_arr)) {
+ bool put_if_absent = false;
+ if (!GetInsertInfo(db, sql, status, &table_info, &default_map, &str_length, &stmt_column_idx_arr, &put_if_absent)) {
return {};
}
auto col_schema = openmldb::schema::SchemaAdapter::ConvertSchema(table_info->column_desc());
- insert_cache =
- std::make_shared<InsertSQLCache>(table_info, col_schema, default_map, str_length,
- SQLInsertRow::GetHoleIdxArr(default_map, stmt_column_idx_arr, col_schema));
+ insert_cache = std::make_shared<InsertSQLCache>(
+ table_info, col_schema, default_map, str_length,
+ SQLInsertRow::GetHoleIdxArr(default_map, stmt_column_idx_arr, col_schema), put_if_absent);
SetCache(db, sql, hybridse::vm::kBatchMode, insert_cache);
return std::make_shared<SQLInsertRows>(table_info, insert_cache->GetSchema(), default_map, str_length,
- insert_cache->GetHoleIdxArr());
+ insert_cache->GetHoleIdxArr(), insert_cache->IsPutIfAbsent());
}
bool SQLClusterRouter::ExecuteDDL(const std::string& db, const std::string& sql, hybridse::sdk::Status* status) {
@@ -1303,7 +1307,8 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s
std::shared_ptr<::openmldb::nameserver::TableInfo> table_info;
std::vector<DefaultValueMap> default_maps;
std::vector<uint32_t> str_lengths;
- if (!GetMultiRowInsertInfo(db, sql, status, &table_info, &default_maps, &str_lengths)) {
+ bool put_if_absent;
+ if (!GetMultiRowInsertInfo(db, sql, status, &table_info, &default_maps, &str_lengths, &put_if_absent)) {
CODE_PREPEND_AND_WARN(status, StatusCode::kCmdError, "Fail to get insert info");
return false;
}
@@ -1318,7 +1323,7 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s
}
std::vector fails;
for (size_t i = 0; i < default_maps.size(); i++) {
- auto row = std::make_shared<SQLInsertRow>(table_info, schema, default_maps[i], str_lengths[i]);
+ auto row = std::make_shared<SQLInsertRow>(table_info, schema, default_maps[i], str_lengths[i], put_if_absent);
if (!row) {
LOG(WARNING) << "fail to parse row[" << i << "]";
fails.push_back(i);
@@ -1368,7 +1373,7 @@ bool SQLClusterRouter::PutRow(uint32_t tid, const std::shared_ptr<SQLInsertRow>&
DLOG(INFO) << "put data to endpoint " << client->GetEndpoint() << " with dimensions size "
<< kv.second.size();
auto ret = client->Put(tid, pid, cur_ts, row->GetRow(), kv.second,
- insert_memory_usage_limit_.load(std::memory_order_relaxed));
+ insert_memory_usage_limit_.load(std::memory_order_relaxed), row->IsPutIfAbsent());
if (!ret.OK()) {
if (RevertPut(row->GetTableInfo(), pid, dimensions, cur_ts,
base::Slice(row->GetRow()), tablets).IsOK()) {
@@ -1448,8 +1453,8 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s
}
bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& name, int tid, int partition_num,
- hybridse::sdk::ByteArrayPtr dimension, int dimension_len,
- hybridse::sdk::ByteArrayPtr value, int len, hybridse::sdk::Status* status) {
+ hybridse::sdk::ByteArrayPtr dimension, int dimension_len,
+ hybridse::sdk::ByteArrayPtr value, int len, bool put_if_absent, hybridse::sdk::Status* status) {
RET_FALSE_IF_NULL_AND_WARN(status, "output status is nullptr");
if (dimension == nullptr || dimension_len <= 0 || value == nullptr || len <= 0 || partition_num <= 0) {
*status = {StatusCode::kCmdError, "invalid parameter"};
@@ -1491,7 +1496,7 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& n
DLOG(INFO) << "put data to endpoint " << client->GetEndpoint() << " with dimensions size "
<< kv.second.size();
auto ret = client->Put(tid, pid, cur_ts, row_value, &kv.second,
- insert_memory_usage_limit_.load(std::memory_order_relaxed));
+ insert_memory_usage_limit_.load(std::memory_order_relaxed), put_if_absent);
if (!ret.OK()) {
SET_STATUS_AND_WARN(status, StatusCode::kCmdError,
"INSERT failed, tid " + std::to_string(tid) +
diff --git a/src/sdk/sql_cluster_router.h b/src/sdk/sql_cluster_router.h
index 502ad07dab6..0b9f6cca272 100644
--- a/src/sdk/sql_cluster_router.h
+++ b/src/sdk/sql_cluster_router.h
@@ -87,7 +87,7 @@ class SQLClusterRouter : public SQLRouter {
bool ExecuteInsert(const std::string& db, const std::string& name, int tid, int partition_num,
hybridse::sdk::ByteArrayPtr dimension, int dimension_len,
- hybridse::sdk::ByteArrayPtr value, int len, hybridse::sdk::Status* status) override;
+ hybridse::sdk::ByteArrayPtr value, int len, bool put_if_absent, hybridse::sdk::Status* status) override;
bool ExecuteDelete(std::shared_ptr<SQLDeleteRow> row, hybridse::sdk::Status* status) override;
@@ -316,10 +316,11 @@ class SQLClusterRouter : public SQLRouter {
bool GetInsertInfo(const std::string& db, const std::string& sql, ::hybridse::sdk::Status* status,
std::shared_ptr<::openmldb::nameserver::TableInfo>* table_info, DefaultValueMap* default_map,
- uint32_t* str_length, std::vector* stmt_column_idx_in_table);
+ uint32_t* str_length, std::vector* stmt_column_idx_in_table, bool* put_if_absent);
bool GetMultiRowInsertInfo(const std::string& db, const std::string& sql, ::hybridse::sdk::Status* status,
std::shared_ptr<::openmldb::nameserver::TableInfo>* table_info,
- std::vector<DefaultValueMap>* default_maps, std::vector<uint32_t>* str_lengths);
+ std::vector<DefaultValueMap>* default_maps, std::vector<uint32_t>* str_lengths,
+ bool* put_if_absent);
DefaultValueMap GetDefaultMap(const std::shared_ptr<::openmldb::nameserver::TableInfo>& table_info,
const std::map<uint32_t, uint32_t>& column_map, ::hybridse::node::ExprListNode* row,
diff --git a/src/sdk/sql_insert_row.cc b/src/sdk/sql_insert_row.cc
index a2d44571be2..492bb80e49b 100644
--- a/src/sdk/sql_insert_row.cc
+++ b/src/sdk/sql_insert_row.cc
@@ -29,33 +29,35 @@ namespace sdk {
SQLInsertRows::SQLInsertRows(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info,
std::shared_ptr<hybridse::sdk::Schema> schema, DefaultValueMap default_map,
- uint32_t default_str_length, const std::vector<uint32_t>& hole_idx_arr)
+ uint32_t default_str_length, const std::vector<uint32_t>& hole_idx_arr, bool put_if_absent)
: table_info_(std::move(table_info)),
schema_(std::move(schema)),
default_map_(std::move(default_map)),
default_str_length_(default_str_length),
- hole_idx_arr_(hole_idx_arr) {}
+ hole_idx_arr_(hole_idx_arr),
+ put_if_absent_(put_if_absent) {}
std::shared_ptr<SQLInsertRow> SQLInsertRows::NewRow() {
if (!rows_.empty() && !rows_.back()->IsComplete()) {
return {};
}
- std::shared_ptr<SQLInsertRow> row =
- std::make_shared<SQLInsertRow>(table_info_, schema_, default_map_, default_str_length_, hole_idx_arr_);
+ std::shared_ptr<SQLInsertRow> row = std::make_shared<SQLInsertRow>(
+ table_info_, schema_, default_map_, default_str_length_, hole_idx_arr_, put_if_absent_);
rows_.push_back(row);
return row;
}
SQLInsertRow::SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info,
std::shared_ptr<hybridse::sdk::Schema> schema, DefaultValueMap default_map,
- uint32_t default_string_length)
+ uint32_t default_string_length, bool put_if_absent)
: table_info_(table_info),
schema_(std::move(schema)),
default_map_(std::move(default_map)),
default_string_length_(default_string_length),
rb_(table_info->column_desc()),
val_(),
- str_size_(0) {
+ str_size_(0),
+ put_if_absent_(put_if_absent) {
std::map<std::string, uint32_t> column_name_map;
for (int idx = 0; idx < table_info_->column_desc_size(); idx++) {
column_name_map.emplace(table_info_->column_desc(idx).name(), idx);
@@ -81,8 +83,9 @@ SQLInsertRow::SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> ta
SQLInsertRow::SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info,
std::shared_ptr<hybridse::sdk::Schema> schema, DefaultValueMap default_map,
- uint32_t default_str_length, std::vector<uint32_t> hole_idx_arr)
- : SQLInsertRow(std::move(table_info), std::move(schema), std::move(default_map), default_str_length) {
+ uint32_t default_str_length, std::vector<uint32_t> hole_idx_arr, bool put_if_absent)
+ : SQLInsertRow(std::move(table_info), std::move(schema), std::move(default_map), default_str_length,
+ put_if_absent) {
hole_idx_arr_ = std::move(hole_idx_arr);
}
diff --git a/src/sdk/sql_insert_row.h b/src/sdk/sql_insert_row.h
index ded1c824e19..af18891587f 100644
--- a/src/sdk/sql_insert_row.h
+++ b/src/sdk/sql_insert_row.h
@@ -103,12 +103,13 @@ class DefaultValueContainer {
class SQLInsertRow {
public:
+ // for raw insert sql (no hole/placeholder columns)
SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info,
std::shared_ptr<hybridse::sdk::Schema> schema, DefaultValueMap default_map,
- uint32_t default_str_length);
+ uint32_t default_str_length, bool put_if_absent);
SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info,
std::shared_ptr<hybridse::sdk::Schema> schema, DefaultValueMap default_map,
- uint32_t default_str_length, std::vector<uint32_t> hole_idx_arr);
+ uint32_t default_str_length, std::vector<uint32_t> hole_idx_arr, bool put_if_absent);
~SQLInsertRow() = default;
bool Init(int str_length);
bool AppendBool(bool val);
@@ -155,6 +156,10 @@ class SQLInsertRow {
return *table_info_;
}
+ bool IsPutIfAbsent() const {
+ return put_if_absent_;
+ }
+
private:
bool MakeDefault();
void PackDimension(const std::string& val);
@@ -175,13 +180,14 @@ class SQLInsertRow {
::openmldb::codec::RowBuilder rb_;
std::string val_;
uint32_t str_size_;
+ bool put_if_absent_;
};
class SQLInsertRows {
public:
SQLInsertRows(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info,
std::shared_ptr<hybridse::sdk::Schema> schema, DefaultValueMap default_map, uint32_t str_size,
- const std::vector<uint32_t>& hole_idx_arr);
+ const std::vector<uint32_t>& hole_idx_arr, bool put_if_absent);
~SQLInsertRows() = default;
std::shared_ptr<SQLInsertRow> NewRow();
inline uint32_t GetCnt() { return rows_.size(); }
@@ -200,6 +206,7 @@ class SQLInsertRows {
DefaultValueMap default_map_;
uint32_t default_str_length_;
std::vector<uint32_t> hole_idx_arr_;
+ bool put_if_absent_;
std::vector<std::shared_ptr<SQLInsertRow>> rows_;
};
diff --git a/src/sdk/sql_router.h b/src/sdk/sql_router.h
index 4317d435f8c..07b2e3b7734 100644
--- a/src/sdk/sql_router.h
+++ b/src/sdk/sql_router.h
@@ -130,7 +130,7 @@ class SQLRouter {
virtual bool ExecuteInsert(const std::string& db, const std::string& name, int tid, int partition_num,
hybridse::sdk::ByteArrayPtr dimension, int dimension_len,
- hybridse::sdk::ByteArrayPtr value, int len, hybridse::sdk::Status* status) = 0;
+ hybridse::sdk::ByteArrayPtr value, int len, bool put_if_absent, hybridse::sdk::Status* status) = 0;
virtual bool ExecuteDelete(std::shared_ptr<SQLDeleteRow> row, hybridse::sdk::Status* status) = 0;
diff --git a/src/storage/aggregator.cc b/src/storage/aggregator.cc
index 7814c687be5..4615c87bc20 100644
--- a/src/storage/aggregator.cc
+++ b/src/storage/aggregator.cc
@@ -641,9 +641,9 @@ bool Aggregator::FlushAggrBuffer(const std::string& key, const std::string& filt
auto dimension = entry.add_dimensions();
dimension->set_idx(aggr_index_pos_);
dimension->set_key(key);
- bool ok = aggr_table_->Put(time, entry.value(), entry.dimensions());
- if (!ok) {
- PDLOG(ERROR, "Aggregator put failed");
+ auto st = aggr_table_->Put(time, entry.value(), entry.dimensions());
+ if (!st.ok()) {
+ LOG(ERROR) << "Aggregator put failed: " << st.ToString();
return false;
}
entry.set_pk(key);
diff --git a/src/storage/disk_table.cc b/src/storage/disk_table.cc
index b41c9f8fd3c..af35ab9a170 100644
--- a/src/storage/disk_table.cc
+++ b/src/storage/disk_table.cc
@@ -227,7 +227,8 @@ bool DiskTable::Put(const std::string& pk, uint64_t time, const char* data, uint
}
}
-bool DiskTable::Put(uint64_t time, const std::string& value, const Dimensions& dimensions) {
+absl::Status DiskTable::Put(uint64_t time, const std::string& value, const Dimensions& dimensions, bool put_if_absent) {
+ // the disk table overwrites a row whose key and ts both match, so put_if_absent needs no special handling
const int8_t* data = reinterpret_cast<const int8_t*>(value.data());
std::string uncompress_data;
if (GetCompressType() == openmldb::type::kSnappy) {
@@ -237,15 +238,14 @@ bool DiskTable::Put(uint64_t time, const std::string& value, const Dimensions& d
uint8_t version = codec::RowView::GetSchemaVersion(data);
auto decoder = GetVersionDecoder(version);
if (decoder == nullptr) {
- PDLOG(WARNING, "invalid schema version %u, tid %u pid %u", version, id_, pid_);
- return false;
+ return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid schema version ", version));
}
rocksdb::WriteBatch batch;
for (auto it = dimensions.begin(); it != dimensions.end(); ++it) {
auto index_def = table_index_.GetIndex(it->idx());
if (!index_def || !index_def->IsReady()) {
- PDLOG(WARNING, "failed putting key %s to dimension %u in table tid %u pid %u", it->key().c_str(),
- it->idx(), id_, pid_);
+ PDLOG(WARNING, "failed putting key %s to dimension %u in table tid %u pid %u", it->key().c_str(), it->idx(),
+ id_, pid_);
}
int32_t inner_pos = table_index_.GetInnerIndexPos(it->idx());
auto inner_index = table_index_.GetInnerIndex(inner_pos);
@@ -256,12 +256,10 @@ bool DiskTable::Put(uint64_t time, const std::string& value, const Dimensions& d
if (ts_col->IsAutoGenTs()) {
ts = time;
} else if (decoder->GetInteger(data, ts_col->GetId(), ts_col->GetType(), &ts) != 0) {
- PDLOG(WARNING, "get ts failed. tid %u pid %u", id_, pid_);
- return false;
+ return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": get ts failed"));
}
if (ts < 0) {
- PDLOG(WARNING, "ts %ld is negative. tid %u pid %u", ts, id_, pid_);
- return false;
+ return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": ts is negative ", ts));
}
if (inner_index->GetIndex().size() > 1) {
combine_key = CombineKeyTs(it->key(), ts, ts_col->GetId());
@@ -275,10 +273,9 @@ bool DiskTable::Put(uint64_t time, const std::string& value, const Dimensions& d
auto s = db_->Write(write_opts_, &batch);
if (s.ok()) {
offset_.fetch_add(1, std::memory_order_relaxed);
- return true;
+ return absl::OkStatus();
} else {
- DEBUGLOG("Put failed. tid %u pid %u msg %s", id_, pid_, s.ToString().c_str());
- return false;
+ return absl::InternalError(absl::StrCat(id_, ".", pid_, ": ", s.ToString()));
}
}
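
The storage layer now reports failures as `absl::Status` with a uniform `"<tid>.<pid>: <reason>"` message instead of logging and returning `false`. A self-contained sketch of that status shape (function name here is illustrative, not the repo's):

```cpp
#include <cstdint>
#include <iostream>
#include <string>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"

// Sketch of the error convention used by the new Put() overloads.
absl::Status MakePutError(uint32_t tid, uint32_t pid, const std::string& reason) {
    return absl::InternalError(absl::StrCat(tid, ".", pid, ": ", reason));
}

int main() {
    std::cout << MakePutError(10, 1, "Corruption: block checksum mismatch") << "\n";
    // prints: INTERNAL: 10.1: Corruption: block checksum mismatch
    return 0;
}
```
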
diff --git a/src/storage/disk_table.h b/src/storage/disk_table.h
index be549d0c2cd..7b471bac45e 100644
--- a/src/storage/disk_table.h
+++ b/src/storage/disk_table.h
@@ -21,6 +21,7 @@
#include
#include
#include
+
#include "base/slice.h"
#include "base/status.h"
#include "common/timer.h"
@@ -102,7 +103,7 @@ class AbsoluteTTLCompactionFilter : public rocksdb::CompactionFilter {
return false;
}
uint32_t ts_idx = *((uint32_t*)(key.data() + key.size() - TS_LEN - // NOLINT
- TS_POS_LEN));
+ TS_POS_LEN));
bool has_found = false;
for (const auto& index : indexs) {
auto ts_col = index->GetTsColumn();
@@ -110,7 +111,7 @@ class AbsoluteTTLCompactionFilter : public rocksdb::CompactionFilter {
return false;
}
if (ts_col->GetId() == ts_idx &&
- index->GetTTL()->ttl_type == openmldb::storage::TTLType::kAbsoluteTime) {
+ index->GetTTL()->ttl_type == openmldb::storage::TTLType::kAbsoluteTime) {
real_ttl = index->GetTTL()->abs_ttl;
has_found = true;
break;
@@ -172,7 +173,8 @@ class DiskTable : public Table {
bool Put(const std::string& pk, uint64_t time, const char* data, uint32_t size) override;
- bool Put(uint64_t time, const std::string& value, const Dimensions& dimensions) override;
+ absl::Status Put(uint64_t time, const std::string& value, const Dimensions& dimensions,
+ bool put_if_absent) override;
bool Get(uint32_t idx, const std::string& pk, uint64_t ts,
std::string& value); // NOLINT
@@ -183,8 +185,8 @@ class DiskTable : public Table {
base::Status Truncate();
- bool Delete(uint32_t idx, const std::string& pk,
- const std::optional<uint64_t>& start_ts, const std::optional<uint64_t>& end_ts) override;
+ bool Delete(uint32_t idx, const std::string& pk, const std::optional<uint64_t>& start_ts,
+ const std::optional<uint64_t>& end_ts) override;
uint64_t GetExpireTime(const TTLSt& ttl_st) override;
@@ -233,7 +235,7 @@ class DiskTable : public Table {
uint64_t GetRecordByteSize() const override { return 0; }
uint64_t GetRecordIdxByteSize() override;
- int GetCount(uint32_t index, const std::string& pk, uint64_t& count) override; // NOLINT
+ int GetCount(uint32_t index, const std::string& pk, uint64_t& count) override; // NOLINT
private:
base::Status Delete(uint32_t idx, const std::string& pk, uint64_t start_ts, const std::optional<uint64_t>& end_ts);
diff --git a/src/storage/disk_table_test.cc b/src/storage/disk_table_test.cc
index 2a4e0d53c98..04a5d6edbb3 100644
--- a/src/storage/disk_table_test.cc
+++ b/src/storage/disk_table_test.cc
@@ -111,7 +111,7 @@ TEST_F(DiskTableTest, MultiDimensionPut) {
mapping.insert(std::make_pair("idx1", 1));
mapping.insert(std::make_pair("idx2", 2));
std::string table_path = FLAGS_hdd_root_path + "/2_1";
- DiskTable* table = new DiskTable("yjtable2", 2, 1, mapping, 10, ::openmldb::type::TTLType::kAbsoluteTime,
+ Table* table = new DiskTable("yjtable2", 2, 1, mapping, 10, ::openmldb::type::TTLType::kAbsoluteTime,
::openmldb::common::StorageMode::kHDD, table_path);
ASSERT_TRUE(table->Init());
ASSERT_EQ(3, (int64_t)table->GetIdxCnt());
@@ -136,7 +136,7 @@ TEST_F(DiskTableTest, MultiDimensionPut) {
d2->set_idx(2);
std::string value;
ASSERT_EQ(0, sdk_codec.EncodeRow(row, &value));
- bool ok = table->Put(1, value, dimensions);
+ bool ok = table->Put(1, value, dimensions).ok();
ASSERT_TRUE(ok);
// some functions in disk table need to be implemented.
// refer to issue #1238
@@ -202,7 +202,7 @@ TEST_F(DiskTableTest, MultiDimensionPut) {
row = {"valuea", "valueb", "valuec"};
ASSERT_EQ(0, sdk_codec.EncodeRow(row, &value));
- ASSERT_TRUE(table->Put(2, value, dimensions));
+ ASSERT_TRUE(table->Put(2, value, dimensions).ok());
it = table->NewIterator(0, "key2", ticket);
it->SeekToFirst();
@@ -223,7 +223,7 @@ TEST_F(DiskTableTest, MultiDimensionPut) {
delete it;
std::string val;
- ASSERT_TRUE(table->Get(1, "key1", 2, val));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, "key1", 2, val));
data = reinterpret_cast(val.data());
version = codec::RowView::GetSchemaVersion(data);
decoder = table->GetVersionDecoder(version);
@@ -277,7 +277,7 @@ TEST_F(DiskTableTest, LongPut) {
mapping.insert(std::make_pair("idx0", 0));
mapping.insert(std::make_pair("idx1", 1));
std::string table_path = FLAGS_ssd_root_path + "/3_1";
- DiskTable* table = new DiskTable("yjtable3", 3, 1, mapping, 10, ::openmldb::type::TTLType::kAbsoluteTime,
+ Table* table = new DiskTable("yjtable3", 3, 1, mapping, 10, ::openmldb::type::TTLType::kAbsoluteTime,
::openmldb::common::StorageMode::kSSD, table_path);
auto meta = ::openmldb::test::GetTableMeta({"idx0", "idx1"});
::openmldb::codec::SDKCodec sdk_codec(meta);
@@ -297,7 +297,7 @@ TEST_F(DiskTableTest, LongPut) {
std::string value;
ASSERT_EQ(0, sdk_codec.EncodeRow(row, &value));
for (int k = 0; k < 10; k++) {
- ASSERT_TRUE(table->Put(ts + k, value, dimensions));
+ ASSERT_TRUE(table->Put(ts + k, value, dimensions).ok());
}
}
for (int idx = 0; idx < 10; idx++) {
@@ -465,7 +465,7 @@ TEST_F(DiskTableTest, TraverseIterator) {
}
ASSERT_EQ(20, count);
std::string val;
- ASSERT_TRUE(table->Get(0, "test98", 9548, val));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, "test98", 9548, val));
ASSERT_EQ("valu8", val);
delete it;
delete table;
@@ -733,7 +733,7 @@ TEST_F(DiskTableTest, CompactFilter) {
std::map mapping;
mapping.insert(std::make_pair("idx0", 0));
std::string table_path = FLAGS_hdd_root_path + "/10_1";
- DiskTable* table = new DiskTable("t1", 10, 1, mapping, 10, ::openmldb::type::TTLType::kAbsoluteTime,
+ Table* table = new DiskTable("t1", 10, 1, mapping, 10, ::openmldb::type::TTLType::kAbsoluteTime,
::openmldb::common::StorageMode::kHDD, table_path);
ASSERT_TRUE(table->Init());
uint64_t cur_time = ::baidu::common::timer::get_micros() / 1000;
@@ -754,24 +754,24 @@ TEST_F(DiskTableTest, CompactFilter) {
for (int k = 0; k < 5; k++) {
std::string value;
if (k > 2) {
- ASSERT_TRUE(table->Get(key, ts - k - 10 * 60 * 1000, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(key, ts - k - 10 * 60 * 1000, value));
ASSERT_EQ("value9", value);
} else {
- ASSERT_TRUE(table->Get(key, ts - k, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(key, ts - k, value));
ASSERT_EQ("value", value);
}
}
}
- table->CompactDB();
+ reinterpret_cast<DiskTable*>(table)->CompactDB();
for (int idx = 0; idx < 100; idx++) {
std::string key = "test" + std::to_string(idx);
uint64_t ts = cur_time;
for (int k = 0; k < 5; k++) {
std::string value;
if (k > 2) {
- ASSERT_FALSE(table->Get(key, ts - k - 10 * 60 * 1000, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(key, ts - k - 10 * 60 * 1000, value));
} else {
- ASSERT_TRUE(table->Get(key, ts - k, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(key, ts - k, value));
ASSERT_EQ("value", value);
}
}
@@ -794,7 +794,7 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
SchemaCodec::SetIndex(table_meta.add_column_key(), "mcc", "mcc", "ts2", ::openmldb::type::kAbsoluteTime, 5, 0);
std::string table_path = FLAGS_hdd_root_path + "/11_1";
- DiskTable* table = new DiskTable(table_meta, table_path);
+ Table* table = new DiskTable(table_meta, table_path);
ASSERT_TRUE(table->Init());
codec::SDKCodec codec(table_meta);
@@ -818,7 +818,7 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
std::to_string(cur_time - i * 60 * 1000)};
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- ASSERT_TRUE(table->Put(cur_time - i * 60 * 1000, value, dims));
+ ASSERT_TRUE(table->Put(cur_time - i * 60 * 1000, value, dims).ok());
}
} else {
@@ -828,7 +828,7 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
std::to_string(cur_time - i)};
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- ASSERT_TRUE(table->Put(cur_time - i, value, dims));
+ ASSERT_TRUE(table->Put(cur_time - i, value, dims).ok());
}
}
}
@@ -860,11 +860,11 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
std::string e_value;
ASSERT_EQ(0, codec.EncodeRow(row, &e_value));
std::string value;
- ASSERT_TRUE(table->Get(0, key, ts - i * 60 * 1000, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, key, ts - i * 60 * 1000, value));
ASSERT_EQ(e_value, value);
- ASSERT_TRUE(table->Get(1, key, ts - i * 60 * 1000, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, ts - i * 60 * 1000, value));
ASSERT_EQ(e_value, value);
- ASSERT_TRUE(table->Get(2, key1, ts - i * 60 * 1000, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, ts - i * 60 * 1000, value));
}
} else {
@@ -874,15 +874,15 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
std::string e_value;
ASSERT_EQ(0, codec.EncodeRow(row, &e_value));
std::string value;
- ASSERT_TRUE(table->Get(0, key, ts - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, key, ts - i, value));
ASSERT_EQ(e_value, value);
- ASSERT_TRUE(table->Get(1, key, ts - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, ts - i, value));
ASSERT_EQ(e_value, value);
- ASSERT_TRUE(table->Get(2, key1, ts - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, ts - i, value));
}
}
}
- table->CompactDB();
+ reinterpret_cast<DiskTable*>(table)->CompactDB();
iter = table->NewIterator(0, "card0", ticket);
iter->SeekToFirst();
while (iter->Valid()) {
@@ -908,18 +908,18 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
ASSERT_EQ(0, codec.EncodeRow(row, &e_value));
std::string value;
if (i < 3) {
- ASSERT_TRUE(table->Get(0, key, cur_ts, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_ts, value));
ASSERT_EQ(e_value, value);
} else {
- ASSERT_FALSE(table->Get(0, key, cur_ts, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_ts, value));
}
if (i < 5) {
- ASSERT_TRUE(table->Get(1, key, cur_ts, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_ts, value));
ASSERT_EQ(e_value, value);
- ASSERT_TRUE(table->Get(2, key1, cur_ts, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_ts, value));
} else {
- ASSERT_FALSE(table->Get(1, key, cur_ts, value));
- ASSERT_FALSE(table->Get(2, key1, cur_ts, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_ts, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_ts, value));
}
}
} else {
@@ -929,11 +929,11 @@ TEST_F(DiskTableTest, CompactFilterMulTs) {
std::string e_value;
ASSERT_EQ(0, codec.EncodeRow(row, &e_value));
std::string value;
- ASSERT_TRUE(table->Get(0, key, ts - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, key, ts - i, value));
ASSERT_EQ(e_value, value);
- ASSERT_TRUE(table->Get(1, key, ts - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, ts - i, value));
ASSERT_EQ(e_value, value);
- ASSERT_TRUE(table->Get(2, key1, ts - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, ts - i, value));
}
}
}
@@ -955,7 +955,8 @@ TEST_F(DiskTableTest, GcHeadMulTs) {
SchemaCodec::SetIndex(table_meta.add_column_key(), "mcc", "mcc", "ts2", ::openmldb::type::kLatestTime, 0, 5);
std::string table_path = FLAGS_hdd_root_path + "/12_1";
- DiskTable* table = new DiskTable(table_meta, table_path);
+ // the Table base class has no Get method; cast back to DiskTable wherever Get is called
+ Table* table = new DiskTable(table_meta, table_path);
ASSERT_TRUE(table->Init());
codec::SDKCodec codec(table_meta);
@@ -980,7 +981,7 @@ TEST_F(DiskTableTest, GcHeadMulTs) {
std::to_string(cur_time - i), std::to_string(cur_time - i)};
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- ASSERT_TRUE(table->Put(cur_time - i, value, dims));
+ ASSERT_TRUE(table->Put(cur_time - i, value, dims).ok());
}
}
Ticket ticket;
@@ -1006,15 +1007,15 @@ TEST_F(DiskTableTest, GcHeadMulTs) {
ASSERT_EQ(0, codec.EncodeRow(row, &e_value));
std::string value;
if (idx == 50 && i > 2) {
- ASSERT_FALSE(table->Get(0, key, cur_time - i, value));
- ASSERT_FALSE(table->Get(1, key, cur_time - i, value));
- ASSERT_FALSE(table->Get(2, key1, cur_time - i, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_time - i, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_time - i, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_time - i, value));
} else {
- ASSERT_TRUE(table->Get(0, key, cur_time - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_time - i, value));
ASSERT_EQ(e_value, value);
- ASSERT_TRUE(table->Get(1, key, cur_time - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_time - i, value));
ASSERT_EQ(e_value, value);
- ASSERT_TRUE(table->Get(2, key1, cur_time - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_time - i, value));
}
}
}
@@ -1041,24 +1042,24 @@ TEST_F(DiskTableTest, GcHeadMulTs) {
ASSERT_EQ(0, codec.EncodeRow(row, &e_value));
std::string value;
if (idx == 50 && i > 2) {
- ASSERT_FALSE(table->Get(0, key, cur_time - i, value));
- ASSERT_FALSE(table->Get(1, key, cur_time - i, value));
- ASSERT_FALSE(table->Get(2, key1, cur_time - i, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_time - i, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_time - i, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_time - i, value));
} else if (i < 3) {
- ASSERT_TRUE(table->Get(0, key, cur_time - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_time - i, value));
ASSERT_EQ(e_value, value);
- ASSERT_TRUE(table->Get(1, key, cur_time - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_time - i, value));
ASSERT_EQ(e_value, value);
- ASSERT_TRUE(table->Get(2, key1, cur_time - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_time - i, value));
} else if (i < 5) {
- ASSERT_FALSE(table->Get(0, key, cur_time - i, value));
- ASSERT_TRUE(table->Get(1, key, cur_time - i, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_time - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_time - i, value));
ASSERT_EQ(e_value, value);
- ASSERT_TRUE(table->Get(2, key1, cur_time - i, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_time - i, value));
} else {
- ASSERT_FALSE(table->Get(0, key, cur_time - i, value));
- ASSERT_FALSE(table->Get(1, key, cur_time - i, value));
- ASSERT_FALSE(table->Get(2, key1, cur_time - i, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(0, key, cur_time - i, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(1, key, cur_time - i, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(2, key1, cur_time - i, value));
}
}
}
@@ -1089,7 +1090,7 @@ TEST_F(DiskTableTest, GcHead) {
uint64_t ts = 9537;
for (int k = 0; k < 5; k++) {
std::string value;
- ASSERT_TRUE(table->Get(key, ts + k, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(key, ts + k, value));
if (idx == 10 && k == 2) {
ASSERT_EQ("value8", value);
} else {
@@ -1104,9 +1105,9 @@ TEST_F(DiskTableTest, GcHead) {
for (int k = 0; k < 5; k++) {
std::string value;
if (k < 2) {
- ASSERT_FALSE(table->Get(key, ts + k, value));
+ ASSERT_FALSE(reinterpret_cast<DiskTable*>(table)->Get(key, ts + k, value));
} else {
- ASSERT_TRUE(table->Get(key, ts + k, value));
+ ASSERT_TRUE(reinterpret_cast<DiskTable*>(table)->Get(key, ts + k, value));
if (idx == 10 && k == 2) {
ASSERT_EQ("value8", value);
} else {
diff --git a/src/storage/key_entry.h b/src/storage/key_entry.h
index e8969c9b832..1b5f4778f4f 100644
--- a/src/storage/key_entry.h
+++ b/src/storage/key_entry.h
@@ -49,11 +49,19 @@ struct DataBlock {
delete[] data;
data = nullptr;
}
+
+ bool EqualWithoutCnt(const DataBlock& other) const {
+ if (size != other.size) {
+ return false;
+ }
+ // possible optimization: skip the header/version prefix during the compare, see RowBuilder::InitBuffer
+ return memcmp(data, other.data, size) == 0;
+ }
};
// the desc time comparator
struct TimeComparator {
- int operator() (uint64_t a, uint64_t b) const {
+ int operator()(uint64_t a, uint64_t b) const {
if (a > b) {
return -1;
} else if (a == b) {
@@ -86,7 +94,6 @@ class KeyEntry {
std::atomic count_;
};
-
} // namespace storage
} // namespace openmldb
diff --git a/src/storage/mem_table.cc b/src/storage/mem_table.cc
index db7578619b7..3a57ffc4e93 100644
--- a/src/storage/mem_table.cc
+++ b/src/storage/mem_table.cc
@@ -140,21 +140,21 @@ bool MemTable::Put(const std::string& pk, uint64_t time, const char* data, uint3
return true;
}
-bool MemTable::Put(uint64_t time, const std::string& value, const Dimensions& dimensions) {
+absl::Status MemTable::Put(uint64_t time, const std::string& value, const Dimensions& dimensions, bool put_if_absent) {
if (dimensions.empty()) {
PDLOG(WARNING, "empty dimension. tid %u pid %u", id_, pid_);
- return false;
+ return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": empty dimension"));
}
if (value.length() < codec::HEADER_LENGTH) {
PDLOG(WARNING, "invalid value. tid %u pid %u", id_, pid_);
- return false;
+ return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid value"));
}
+ // inner index pos: -1 means invalid, so only non-negative positions enter inner_index_key_map
std::map<int32_t, std::string> inner_index_key_map;
for (auto iter = dimensions.begin(); iter != dimensions.end(); iter++) {
int32_t inner_pos = table_index_.GetInnerIndexPos(iter->idx());
if (inner_pos < 0) {
- PDLOG(WARNING, "invalid dimension. dimension idx %u, tid %u pid %u", iter->idx(), id_, pid_);
- return false;
+ return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid dimension idx ", iter->idx()));
}
inner_index_key_map.emplace(inner_pos, iter->key());
}
@@ -168,15 +168,13 @@ bool MemTable::Put(uint64_t time, const std::string& value, const Dimensions& di
uint8_t version = codec::RowView::GetSchemaVersion(data);
auto decoder = GetVersionDecoder(version);
if (decoder == nullptr) {
- PDLOG(WARNING, "invalid schema version %u, tid %u pid %u", version, id_, pid_);
- return false;
+ return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid schema version ", version));
}
std::map<int32_t, std::map<int32_t, uint64_t>> ts_value_map;
for (const auto& kv : inner_index_key_map) {
auto inner_index = table_index_.GetInnerIndex(kv.first);
if (!inner_index) {
- PDLOG(WARNING, "invalid inner index pos %d. tid %u pid %u", kv.first, id_, pid_);
- return false;
+ return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": invalid inner index pos ", kv.first));
}
std::map<int32_t, uint64_t> ts_map;
for (const auto& index_def : inner_index->GetIndex()) {
@@ -189,13 +187,12 @@ bool MemTable::Put(uint64_t time, const std::string& value, const Dimensions& di
if (ts_col->IsAutoGenTs()) {
ts = time;
} else if (decoder->GetInteger(data, ts_col->GetId(), ts_col->GetType(), &ts) != 0) {
- PDLOG(WARNING, "get ts failed. tid %u pid %u", id_, pid_);
- return false;
+ return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": get ts failed"));
}
if (ts < 0) {
- PDLOG(WARNING, "ts %ld is negative. tid %u pid %u", ts, id_, pid_);
- return false;
+ return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": ts is negative ", ts));
}
+ // TODO(hw): why uint32_t to int32_t?
ts_map.emplace(ts_col->GetId(), ts);
real_ref_cnt++;
}
@@ -205,7 +202,7 @@ bool MemTable::Put(uint64_t time, const std::string& value, const Dimensions& di
}
}
if (ts_value_map.empty()) {
- return false;
+ return absl::InvalidArgumentError(absl::StrCat(id_, ".", pid_, ": empty ts value map"));
}
auto* block = new DataBlock(real_ref_cnt, value.c_str(), value.length());
for (const auto& kv : inner_index_key_map) {
@@ -218,10 +215,12 @@ bool MemTable::Put(uint64_t time, const std::string& value, const Dimensions& di
seg_idx = ::openmldb::base::hash(kv.second.data(), kv.second.size(), SEED) % seg_cnt_;
}
Segment* segment = segments_[kv.first][seg_idx];
- segment->Put(::openmldb::base::Slice(kv.second), iter->second, block);
+ if (!segment->Put(kv.second, iter->second, block, put_if_absent)) {
+ return absl::AlreadyExistsError("data exists"); // let the caller distinguish duplicates from real failures
+ }
}
record_byte_size_.fetch_add(GetRecordSize(value.length()));
- return true;
+ return absl::OkStatus();
}
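
Because the duplicate case is now a distinct `kAlreadyExists` code, callers can separate "row already there, safe to ignore and retry" from genuine failures. A hedged sketch of caller-side handling (the helper name is illustrative; `absl::IsAlreadyExists` and the repo's glog wrapper are the only real APIs used):

```cpp
#include "absl/status/status.h"
#include "base/glog_wrapper.h"

// st: result of MemTable::Put(..., /*put_if_absent=*/true)
bool HandlePutStatus(const absl::Status& st) {
    if (st.ok()) return true;                    // row inserted
    if (absl::IsAlreadyExists(st)) return true;  // duplicate under put_if_absent: ignorable
    LOG(WARNING) << "put failed: " << st.ToString();
    return false;
}
```
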
bool MemTable::Delete(const ::openmldb::api::LogEntry& entry) {
@@ -550,6 +549,7 @@ uint64_t MemTable::GetRecordIdxCnt() {
if (!index_def || !index_def->IsReady()) {
return record_idx_cnt;
}
+
uint32_t inner_idx = index_def->GetInnerPos();
auto inner_index = table_index_.GetInnerIndex(inner_idx);
int32_t ts_col_id = -1;
@@ -665,7 +665,7 @@ bool MemTable::AddIndex(const ::openmldb::common::ColumnKey& column_key) {
}
ts_vec.push_back(ts_iter->second.GetId());
} else {
- ts_vec.push_back(DEFUALT_TS_COL_ID);
+ ts_vec.push_back(DEFAULT_TS_COL_ID);
}
uint32_t inner_id = table_index_.GetAllInnerIndex()->size();
Segment** seg_arr = new Segment*[seg_cnt_];
@@ -685,7 +685,7 @@ bool MemTable::AddIndex(const ::openmldb::common::ColumnKey& column_key) {
auto ts_iter = schema.find(column_key.ts_name());
index_def->SetTsColumn(std::make_shared<ColumnDef>(ts_iter->second));
} else {
- index_def->SetTsColumn(std::make_shared<ColumnDef>(DEFUALT_TS_COL_NAME, DEFUALT_TS_COL_ID,
+ index_def->SetTsColumn(std::make_shared<ColumnDef>(DEFAULT_TS_COL_NAME, DEFAULT_TS_COL_ID,
::openmldb::type::kTimestamp, true));
}
if (column_key.has_ttl()) {
@@ -724,14 +724,14 @@ bool MemTable::DeleteIndex(const std::string& idx_name) {
new_table_meta->mutable_column_key(index_def->GetId())->set_flag(1);
}
std::atomic_store_explicit(&table_meta_, new_table_meta, std::memory_order_release);
- index_def->SetStatus(IndexStatus::kWaiting);
+ index_def->SetStatus(IndexStatus::kWaiting); // let gc do deletion
return true;
}
::hybridse::vm::WindowIterator* MemTable::NewWindowIterator(uint32_t index) {
std::shared_ptr index_def = table_index_.GetIndex(index);
if (!index_def || !index_def->IsReady()) {
- LOG(WARNING) << "index id " << index << " not found. tid " << id_ << " pid " << pid_;
+ LOG(WARNING) << "index id " << index << " not found. tid " << id_ << " pid " << pid_;
return nullptr;
}
uint64_t expire_time = 0;
diff --git a/src/storage/mem_table.h b/src/storage/mem_table.h
index 8ae1964e0ef..e85762a97dc 100644
--- a/src/storage/mem_table.h
+++ b/src/storage/mem_table.h
@@ -51,7 +51,8 @@ class MemTable : public Table {
bool Put(const std::string& pk, uint64_t time, const char* data, uint32_t size) override;
- bool Put(uint64_t time, const std::string& value, const Dimensions& dimensions) override;
+ absl::Status Put(uint64_t time, const std::string& value, const Dimensions& dimensions,
+ bool put_if_absent) override;
bool GetBulkLoadInfo(::openmldb::api::BulkLoadInfoResponse* response);
@@ -59,8 +60,8 @@ class MemTable : public Table {
const ::google::protobuf::RepeatedPtrField<::openmldb::api::BulkLoadIndex>& indexes);
bool Delete(const ::openmldb::api::LogEntry& entry) override;
- bool Delete(uint32_t idx, const std::string& key,
- const std::optional<uint64_t>& start_ts, const std::optional<uint64_t>& end_ts);
+ bool Delete(uint32_t idx, const std::string& key, const std::optional<uint64_t>& start_ts,
+ const std::optional<uint64_t>& end_ts);
// use the first dimension
TableIterator* NewIterator(const std::string& pk, Ticket& ticket) override;
diff --git a/src/storage/schema.cc b/src/storage/schema.cc
index 7efff1d35a4..3250a047a8b 100644
--- a/src/storage/schema.cc
+++ b/src/storage/schema.cc
@@ -216,7 +216,7 @@ int TableIndex::ParseFromMeta(const ::openmldb::api::TableMeta& table_meta) {
index->SetTsColumn(col_map[ts_name]);
} else {
// set default ts col
- index->SetTsColumn(std::make_shared<ColumnDef>(DEFUALT_TS_COL_NAME, DEFUALT_TS_COL_ID,
+ index->SetTsColumn(std::make_shared<ColumnDef>(DEFAULT_TS_COL_NAME, DEFAULT_TS_COL_ID,
::openmldb::type::kTimestamp, true));
}
if (column_key.has_ttl()) {
@@ -232,7 +232,7 @@ int TableIndex::ParseFromMeta(const ::openmldb::api::TableMeta& table_meta) {
// add default dimension
if (indexs_->empty()) {
auto index = std::make_shared<IndexDef>("idx0", 0);
- index->SetTsColumn(std::make_shared<ColumnDef>(DEFUALT_TS_COL_NAME, DEFUALT_TS_COL_ID,
+ index->SetTsColumn(std::make_shared<ColumnDef>(DEFAULT_TS_COL_NAME, DEFAULT_TS_COL_ID,
::openmldb::type::kTimestamp, true));
if (AddIndex(index) < 0) {
DLOG(WARNING) << "add index failed";
diff --git a/src/storage/schema.h b/src/storage/schema.h
index 2143744122e..39359761ed9 100644
--- a/src/storage/schema.h
+++ b/src/storage/schema.h
@@ -31,8 +31,8 @@
namespace openmldb::storage {
static constexpr uint32_t MAX_INDEX_NUM = 200;
-static constexpr uint32_t DEFUALT_TS_COL_ID = UINT32_MAX;
-static constexpr const char* DEFUALT_TS_COL_NAME = "default_ts";
+static constexpr uint32_t DEFAULT_TS_COL_ID = UINT32_MAX;
+static constexpr const char* DEFAULT_TS_COL_NAME = "default_ts";
enum TTLType { kAbsoluteTime = 1, kRelativeTime = 2, kLatestTime = 3, kAbsAndLat = 4, kAbsOrLat = 5 };
@@ -163,7 +163,7 @@ class ColumnDef {
return false;
}
- inline bool IsAutoGenTs() const { return id_ == DEFUALT_TS_COL_ID; }
+ inline bool IsAutoGenTs() const { return id_ == DEFAULT_TS_COL_ID; }
private:
std::string name_;
diff --git a/src/storage/schema_test.cc b/src/storage/schema_test.cc
index 1c169697634..77840c13e93 100644
--- a/src/storage/schema_test.cc
+++ b/src/storage/schema_test.cc
@@ -233,9 +233,9 @@ TEST_F(SchemaTest, TsAndDefaultTs) {
::openmldb::storage::kAbsoluteTime);
AssertIndex(*(table_index.GetIndex("key2")), "key2", "col1", "col7", 7, 10, 0, ::openmldb::storage::kAbsoluteTime);
AssertIndex(*(table_index.GetIndex("key3")), "key3", "col2", "col6", 6, 10, 0, ::openmldb::storage::kAbsoluteTime);
- AssertIndex(*(table_index.GetIndex("key4")), "key4", "col2", DEFUALT_TS_COL_NAME, DEFUALT_TS_COL_ID,
+ AssertIndex(*(table_index.GetIndex("key4")), "key4", "col2", DEFAULT_TS_COL_NAME, DEFAULT_TS_COL_ID,
10, 0, ::openmldb::storage::kAbsoluteTime);
- AssertIndex(*(table_index.GetIndex("key5")), "key5", "col3", DEFUALT_TS_COL_NAME, DEFUALT_TS_COL_ID,
+ AssertIndex(*(table_index.GetIndex("key5")), "key5", "col3", DEFAULT_TS_COL_NAME, DEFAULT_TS_COL_ID,
10, 0, ::openmldb::storage::kAbsoluteTime);
auto inner_index = table_index.GetAllInnerIndex();
ASSERT_EQ(inner_index->size(), 3u);
@@ -243,10 +243,10 @@ TEST_F(SchemaTest, TsAndDefaultTs) {
std::vector<uint32_t> ts_vec0 = {6, 7};
AssertInnerIndex(*(table_index.GetInnerIndex(0)), 0, index0, ts_vec0);
std::vector<std::string> index1 = {"key3", "key4"};
- std::vector<uint32_t> ts_vec1 = {6, DEFUALT_TS_COL_ID};
+ std::vector<uint32_t> ts_vec1 = {6, DEFAULT_TS_COL_ID};
AssertInnerIndex(*(table_index.GetInnerIndex(1)), 1, index1, ts_vec1);
std::vector<std::string> index2 = {"key5"};
- std::vector<uint32_t> ts_vec2 = {DEFUALT_TS_COL_ID};
+ std::vector<uint32_t> ts_vec2 = {DEFAULT_TS_COL_ID};
AssertInnerIndex(*(table_index.GetInnerIndex(2)), 2, index2, ts_vec2);
}
diff --git a/src/storage/segment.cc b/src/storage/segment.cc
index d79b6e85681..8255d27b7bd 100644
--- a/src/storage/segment.cc
+++ b/src/storage/segment.cc
@@ -15,7 +15,9 @@
*/
#include "storage/segment.h"
+
#include
+
#include
#include "base/glog_wrapper.h"
@@ -64,9 +66,7 @@ Segment::Segment(uint8_t height, const std::vector& ts_idx_vec)
}
}
-Segment::~Segment() {
- delete entries_;
-}
+Segment::~Segment() { delete entries_; }
void Segment::Release(StatisticsInfo* statistics_info) {
std::unique_ptr<KeyEntries::Iterator> it(entries_->NewIterator());
@@ -98,9 +98,7 @@ void Segment::Release(StatisticsInfo* statistics_info) {
}
}
-void Segment::ReleaseAndCount(StatisticsInfo* statistics_info) {
- Release(statistics_info);
-}
+void Segment::ReleaseAndCount(StatisticsInfo* statistics_info) { Release(statistics_info); }
void Segment::ReleaseAndCount(const std::vector& id_vec, StatisticsInfo* statistics_info) {
if (ts_cnt_ <= 1) {
@@ -135,25 +133,28 @@ void Segment::ReleaseAndCount(const std::vector& id_vec, StatisticsInfo*
}
}
-void Segment::Put(const Slice& key, uint64_t time, const char* data, uint32_t size) {
+void Segment::Put(const Slice& key, uint64_t time, const char* data, uint32_t size, bool put_if_absent,
+ bool check_all_time) {
if (ts_cnt_ > 1) {
return;
}
auto* db = new DataBlock(1, data, size);
- Put(key, time, db);
+ Put(key, time, db, put_if_absent, check_all_time);
}
-void Segment::Put(const Slice& key, uint64_t time, DataBlock* row) {
+bool Segment::Put(const Slice& key, uint64_t time, DataBlock* row, bool put_if_absent, bool check_all_time) {
if (ts_cnt_ > 1) {
- return;
+ LOG(ERROR) << "wrong call";
+ return false;
}
std::lock_guard lock(mu_);
- PutUnlock(key, time, row);
+ return PutUnlock(key, time, row, put_if_absent, check_all_time);
}
-void Segment::PutUnlock(const Slice& key, uint64_t time, DataBlock* row) {
+bool Segment::PutUnlock(const Slice& key, uint64_t time, DataBlock* row, bool put_if_absent, bool check_all_time) {
void* entry = nullptr;
uint32_t byte_size = 0;
+ // exactly one entry per key
int ret = entries_->Get(key, entry);
if (ret < 0 || entry == nullptr) {
char* pk = new char[key.size()];
@@ -164,12 +165,17 @@ void Segment::PutUnlock(const Slice& key, uint64_t time, DataBlock* row) {
uint8_t height = entries_->Insert(skey, entry);
byte_size += GetRecordPkIdxSize(height, key.size(), key_entry_max_height_);
pk_cnt_.fetch_add(1, std::memory_order_relaxed);
+ // a brand-new key has no existing records, so the absent check is skipped
+ } else if (put_if_absent && ListContains(reinterpret_cast<KeyEntry*>(entry), time, row, check_all_time)) {
+ return false;
}
+
idx_cnt_vec_[0]->fetch_add(1, std::memory_order_relaxed);
uint8_t height = reinterpret_cast<KeyEntry*>(entry)->entries.Insert(time, row);
reinterpret_cast<KeyEntry*>(entry)->count_.fetch_add(1, std::memory_order_relaxed);
byte_size += GetRecordTsIdxSize(height);
idx_byte_size_.fetch_add(byte_size, std::memory_order_relaxed);
+ return true;
}
void Segment::BulkLoadPut(unsigned int key_entry_id, const Slice& key, uint64_t time, DataBlock* row) {
@@ -201,16 +207,17 @@ void Segment::BulkLoadPut(unsigned int key_entry_id, const Slice& key, uint64_t
}
}
-void Segment::Put(const Slice& key, const std::map<int32_t, uint64_t>& ts_map, DataBlock* row) {
- uint32_t ts_size = ts_map.size();
- if (ts_size == 0) {
- return;
+bool Segment::Put(const Slice& key, const std::map<int32_t, uint64_t>& ts_map, DataBlock* row, bool put_if_absent) {
+ if (ts_map.empty()) {
+ return false;
}
if (ts_cnt_ == 1) {
+ bool ret = false;
if (auto pos = ts_map.find(ts_idx_map_.begin()->first); pos != ts_map.end()) {
- Put(key, pos->second, row);
+ // TODO(hw): why ts_map key is int32_t, default ts is uint32_t?
+ ret = Put(key, pos->second, row, put_if_absent, pos->first == DEFAULT_TS_COL_ID);
}
- return;
+ return ret;
}
void* entry_arr = nullptr;
std::lock_guard lock(mu_);
@@ -237,12 +244,16 @@ void Segment::Put(const Slice& key, const std::map& ts_map, D
}
}
auto entry = reinterpret_cast<KeyEntry**>(entry_arr)[pos->second];
+ if (put_if_absent && ListContains(entry, kv.second, row, pos->first == DEFAULT_TS_COL_ID)) {
+ return false;
+ }
uint8_t height = entry->entries.Insert(kv.second, row);
entry->count_.fetch_add(1, std::memory_order_relaxed);
byte_size += GetRecordTsIdxSize(height);
idx_byte_size_.fetch_add(byte_size, std::memory_order_relaxed);
idx_cnt_vec_[pos->second]->fetch_add(1, std::memory_order_relaxed);
}
+ return true;
}
bool Segment::Delete(const std::optional<uint32_t>& idx, const Slice& key) {
@@ -289,8 +300,8 @@ bool Segment::Delete(const std::optional& idx, const Slice& key) {
return true;
}
-bool Segment::Delete(const std::optional<uint32_t>& idx, const Slice& key,
- uint64_t ts, const std::optional<uint64_t>& end_ts) {
+bool Segment::Delete(const std::optional<uint32_t>& idx, const Slice& key, uint64_t ts,
+ const std::optional<uint64_t>& end_ts) {
void* entry = nullptr;
if (entries_->Get(key, entry) < 0 || entry == nullptr) {
return true;
@@ -347,7 +358,7 @@ bool Segment::Delete(const std::optional& idx, const Slice& key,
}
void Segment::FreeList(uint32_t ts_idx, ::openmldb::base::Node<uint64_t, DataBlock*>* node,
- StatisticsInfo* statistics_info) {
+ StatisticsInfo* statistics_info) {
while (node != nullptr) {
statistics_info->IncrIdxCnt(ts_idx);
::openmldb::base::Node<uint64_t, DataBlock*>* tmp = node;
@@ -365,7 +376,6 @@ void Segment::FreeList(uint32_t ts_idx, ::openmldb::base::Node<uint64_t, DataBlock*>* node,
+bool Segment::ListContains(KeyEntry* entry, uint64_t time, DataBlock* row, bool check_all_time) {
+ std::unique_ptr<TimeEntries::Iterator> it(entry->entries.NewIterator());
+ if (check_all_time) {
+ it->SeekToFirst();
+ while (it->Valid()) {
+ if (it->GetValue()->EqualWithoutCnt(*row)) {
+ return true;
+ }
+ it->Next();
+ }
+ } else {
+ // Seek() uses "less than" under the descending time comparator, so it lands on the first node with
+ // key <= time (invalid if the list is empty or every key > time); Next() then moves to smaller keys
+ it->Seek(time);
+ while (it->Valid()) {
+ // key > time is just a protection, normally it should not happen
+ if (it->GetKey() < time || it->GetKey() > time) {
+ break; // no entry == time, or all entries == time have been checked
+ }
+ if (it->GetValue()->EqualWithoutCnt(*row)) {
+ return true;
+ }
+ it->Next();
+ }
+ }
+ return false;
+}
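
A self-contained sketch of the two probe modes in `ListContains`, with a `std::multimap` in descending time order standing in for the skiplist: an auto-generated timestamp differs on every insert, so only a full scan of the key's list can catch a duplicate payload, while a timestamp taken from the row itself only needs the nodes whose time equals the incoming one.

```cpp
#include <cstdint>
#include <functional>
#include <map>
#include <string>

using TimeList = std::multimap<uint64_t, std::string, std::greater<uint64_t>>;

bool ContainsSketch(const TimeList& list, uint64_t time, const std::string& row,
                    bool check_all_time) {
    if (check_all_time) {  // auto-gen ts: only payload equality can identify a duplicate
        for (const auto& kv : list) {
            if (kv.second == row) return true;
        }
        return false;
    }
    // lower_bound under std::greater yields the first node with key <= time,
    // mirroring Seek() on the descending skiplist; stop once the key != time.
    for (auto it = list.lower_bound(time); it != list.end() && it->first == time; ++it) {
        if (it->second == row) return true;
    }
    return false;
}
```
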
+
// fast gc with no global pause
void Segment::Gc4TTL(const uint64_t time, StatisticsInfo* statistics_info) {
uint64_t consumed = ::baidu::common::timer::get_micros();
@@ -606,8 +645,7 @@ void Segment::Gc4TTL(const uint64_t time, StatisticsInfo* statistics_info) {
if (node == nullptr) {
continue;
} else if (node->GetKey() > time) {
- DEBUGLOG("[Gc4TTL] segment gc with key %lu need not ttl, last node key %lu",
- time, node->GetKey());
+ DEBUGLOG("[Gc4TTL] segment gc with key %lu need not ttl, last node key %lu", time, node->GetKey());
continue;
}
node = nullptr;
@@ -648,8 +686,7 @@ void Segment::Gc4TTLAndHead(const uint64_t time, const uint64_t keep_cnt, Statis
if (node == nullptr) {
continue;
} else if (node->GetKey() > time) {
- DEBUGLOG("[Gc4TTLAndHead] segment gc with key %lu need not ttl, last node key %lu",
- time, node->GetKey());
+ DEBUGLOG("[Gc4TTLAndHead] segment gc with key %lu need not ttl, last node key %lu", time, node->GetKey());
continue;
}
node = nullptr;
@@ -663,8 +700,8 @@ void Segment::Gc4TTLAndHead(const uint64_t time, const uint64_t keep_cnt, Statis
FreeList(0, node, statistics_info);
entry->count_.fetch_sub(statistics_info->GetIdxCnt(0) - cur_idx_cnt, std::memory_order_relaxed);
}
- DEBUGLOG("[Gc4TTLAndHead] segment gc time %lu and keep cnt %lu consumed %lu, count %lu",
- time, keep_cnt, (::baidu::common::timer::get_micros() - consumed) / 1000, statistics_info->GetIdxCnt(0) - old);
+ DEBUGLOG("[Gc4TTLAndHead] segment gc time %lu and keep cnt %lu consumed %lu, count %lu", time, keep_cnt,
+ (::baidu::common::timer::get_micros() - consumed) / 1000, statistics_info->GetIdxCnt(0) - old);
idx_cnt_vec_[0]->fetch_sub(statistics_info->GetIdxCnt(0) - old, std::memory_order_relaxed);
}
@@ -709,8 +746,8 @@ void Segment::Gc4TTLOrHead(const uint64_t time, const uint64_t keep_cnt, Statist
FreeList(0, node, statistics_info);
entry->count_.fetch_sub(statistics_info->GetIdxCnt(0) - cur_idx_cnt, std::memory_order_relaxed);
}
- DEBUGLOG("[Gc4TTLAndHead] segment gc time %lu and keep cnt %lu consumed %lu, count %lu",
- time, keep_cnt, (::baidu::common::timer::get_micros() - consumed) / 1000, statistics_info->GetIdxCnt(0) - old);
+ DEBUGLOG("[Gc4TTLAndHead] segment gc time %lu and keep cnt %lu consumed %lu, count %lu", time, keep_cnt,
+ (::baidu::common::timer::get_micros() - consumed) / 1000, statistics_info->GetIdxCnt(0) - old);
idx_cnt_vec_[0]->fetch_sub(statistics_info->GetIdxCnt(0) - old, std::memory_order_relaxed);
}
@@ -754,8 +791,8 @@ MemTableIterator* Segment::NewIterator(const Slice& key, Ticket& ticket, type::C
return new MemTableIterator(reinterpret_cast(entry)->entries.NewIterator(), compress_type);
}
-MemTableIterator* Segment::NewIterator(const Slice& key, uint32_t idx,
- Ticket& ticket, type::CompressType compress_type) {
+MemTableIterator* Segment::NewIterator(const Slice& key, uint32_t idx, Ticket& ticket,
+ type::CompressType compress_type) {
auto pos = ts_idx_map_.find(idx);
if (pos == ts_idx_map_.end()) {
return new MemTableIterator(nullptr, compress_type);
diff --git a/src/storage/segment.h b/src/storage/segment.h
index fe58dd893a0..11322483832 100644
--- a/src/storage/segment.h
+++ b/src/storage/segment.h
@@ -70,20 +70,19 @@ class Segment {
Segment(uint8_t height, const std::vector<uint32_t>& ts_idx_vec);
~Segment();
- // Put time data
- void Put(const Slice& key, uint64_t time, const char* data, uint32_t size);
+ // legacy interface, used by MemTable and unit tests
+ void Put(const Slice& key, uint64_t time, const char* data, uint32_t size, bool put_if_absent = false,
+ bool check_all_time = false);
- void Put(const Slice& key, uint64_t time, DataBlock* row);
-
- void PutUnlock(const Slice& key, uint64_t time, DataBlock* row);
+ bool Put(const Slice& key, uint64_t time, DataBlock* row, bool put_if_absent = false, bool check_all_time = false);
void BulkLoadPut(unsigned int key_entry_id, const Slice& key, uint64_t time, DataBlock* row);
-
- void Put(const Slice& key, const std::map<int32_t, uint64_t>& ts_map, DataBlock* row);
+ // main put method
+ bool Put(const Slice& key, const std::map<int32_t, uint64_t>& ts_map, DataBlock* row, bool put_if_absent = false);
bool Delete(const std::optional<uint32_t>& idx, const Slice& key);
- bool Delete(const std::optional<uint32_t>& idx, const Slice& key,
- uint64_t ts, const std::optional<uint64_t>& end_ts);
+ bool Delete(const std::optional<uint32_t>& idx, const Slice& key, uint64_t ts,
+ const std::optional<uint64_t>& end_ts);
void Release(StatisticsInfo* statistics_info);
@@ -97,12 +96,10 @@ class Segment {
void GcAllType(const std::map& ttl_st_map, StatisticsInfo* statistics_info);
MemTableIterator* NewIterator(const Slice& key, Ticket& ticket, type::CompressType compress_type); // NOLINT
- MemTableIterator* NewIterator(const Slice& key, uint32_t idx,
- Ticket& ticket, type::CompressType compress_type); // NOLINT
+ MemTableIterator* NewIterator(const Slice& key, uint32_t idx, Ticket& ticket, // NOLINT
+ type::CompressType compress_type);
- uint64_t GetIdxCnt() const {
- return idx_cnt_vec_[0]->load(std::memory_order_relaxed);
- }
+ uint64_t GetIdxCnt() const { return idx_cnt_vec_[0]->load(std::memory_order_relaxed); }
int GetIdxCnt(uint32_t ts_idx, uint64_t& ts_cnt) { // NOLINT
uint32_t real_idx = 0;
@@ -145,10 +142,14 @@ class Segment {
void ReleaseAndCount(const std::vector& id_vec, StatisticsInfo* statistics_info);
private:
- void FreeList(uint32_t ts_idx, ::openmldb::base::Node<uint64_t, DataBlock*>* node,
- StatisticsInfo* statistics_info);
+ void FreeList(uint32_t ts_idx, ::openmldb::base::Node<uint64_t, DataBlock*>* node, StatisticsInfo* statistics_info);
void SplitList(KeyEntry* entry, uint64_t ts, ::openmldb::base::Node** node);
+ bool ListContains(KeyEntry* entry, uint64_t time, DataBlock* row, bool check_all_time);
+
+ bool PutUnlock(const Slice& key, uint64_t time, DataBlock* row, bool put_if_absent = false,
+ bool check_all_time = false);
+
private:
KeyEntries* entries_;
std::mutex mu_;
diff --git a/src/storage/segment_test.cc b/src/storage/segment_test.cc
index c51c0984473..e43461c47e6 100644
--- a/src/storage/segment_test.cc
+++ b/src/storage/segment_test.cc
@@ -424,6 +424,82 @@ TEST_F(SegmentTest, TestDeleteRange) {
CheckStatisticsInfo(CreateStatisticsInfo(20, 1012, 20 * (6 + sizeof(DataBlock))), gc_info);
}
+TEST_F(SegmentTest, PutIfAbsent) {
+ {
+ Segment segment(8); // so ts_cnt_ == 1
+ // check all time == false
+ segment.Put("PK", 1, "test1", 5, true);
+ segment.Put("PK", 1, "test2", 5, true); // same key & time, but a different value counts as a distinct record
+ ASSERT_EQ(2, (int64_t)segment.GetIdxCnt());
+ ASSERT_EQ(1, (int64_t)segment.GetPkCnt());
+ segment.Put("PK", 2, "test3", 5, true);
+ segment.Put("PK", 2, "test4", 5, true);
+ segment.Put("PK", 3, "test5", 5, true);
+ segment.Put("PK", 3, "test6", 5, true);
+ ASSERT_EQ(6, (int64_t)segment.GetIdxCnt());
+ // insert exists rows
+ segment.Put("PK", 2, "test3", 5, true);
+ segment.Put("PK", 1, "test1", 5, true);
+ segment.Put("PK", 1, "test2", 5, true);
+ segment.Put("PK", 3, "test6", 5, true);
+ ASSERT_EQ(6, (int64_t)segment.GetIdxCnt());
+ // new rows
+ segment.Put("PK", 2, "test7", 5, true);
+ ASSERT_EQ(7, (int64_t)segment.GetIdxCnt());
+ segment.Put("PK", 0, "test8", 5, true); // time 0 seeks past the last node, so there is nothing equal to compare
+ ASSERT_EQ(8, (int64_t)segment.GetIdxCnt());
+ }
+
+ {
+ // support when ts_cnt_ != 1 too
+ std::vector<uint32_t> ts_idx_vec = {1, 3};
+ Segment segment(8, ts_idx_vec);
+ ASSERT_EQ(2, (int64_t)segment.GetTsCnt());
+ std::string key = "PK";
+ uint64_t ts = 1669013677221000;
+ // the same ts
+ for (int j = 0; j < 2; j++) {
+ DataBlock* data = new DataBlock(2, key.c_str(), key.length());
+ std::map<int32_t, uint64_t> ts_map = {{1, ts}, {3, ts}};
+ segment.Put(Slice(key), ts_map, data, true);
+ }
+ ASSERT_EQ(1, GetCount(&segment, 1));
+ ASSERT_EQ(1, GetCount(&segment, 3));
+ }
+
+ {
+ // put ts_map contains DEFAULT_TS_COL_ID
+ std::vector<uint32_t> ts_idx_vec = {DEFAULT_TS_COL_ID};
+ Segment segment(8, ts_idx_vec);
+ ASSERT_EQ(1, (int64_t)segment.GetTsCnt());
+ std::string key = "PK";
+ std::map ts_map = {{DEFAULT_TS_COL_ID, 100}}; // cur time == 100
+ auto* block = new DataBlock(1, "test1", 5);
+ segment.Put(Slice(key), ts_map, block, true);
+ ASSERT_EQ(1, GetCount(&segment, DEFAULT_TS_COL_ID));
+ ts_map = {{DEFAULT_TS_COL_ID, 200}};
+ block = new DataBlock(1, "test1", 5);
+ segment.Put(Slice(key), ts_map, block, true);
+ ASSERT_EQ(1, GetCount(&segment, DEFAULT_TS_COL_ID));
+ }
+
+ {
+ // ts_map contains DEFAULT_TS_COL_ID, and the segment has multiple ts indexes
+ std::vector<uint32_t> ts_idx_vec = {DEFAULT_TS_COL_ID, 1, 3};
+ Segment segment(8, ts_idx_vec);
+ ASSERT_EQ(3, (int64_t)segment.GetTsCnt());
+ std::string key = "PK";
+ std::map<uint32_t, uint64_t> ts_map = {{DEFAULT_TS_COL_ID, 100}}; // cur time == 100
+ auto* block = new DataBlock(1, "test1", 5);
+ segment.Put(Slice(key), ts_map, block, true);
+ ASSERT_EQ(1, GetCount(&segment, DEFAULT_TS_COL_ID));
+ ts_map = {{DEFAULT_TS_COL_ID, 200}};
+ block = new DataBlock(1, "test1", 5);
+ segment.Put(Slice(key), ts_map, block, true);
+ ASSERT_EQ(1, GetCount(&segment, DEFAULT_TS_COL_ID));
+ }
+}
+
} // namespace storage
} // namespace openmldb
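The test above fixes two different identities for a duplicate: with an explicit ts column a record is identified by (key, time, value), while under `DEFAULT_TS_COL_ID` the time is the insert wall clock, so only (key, value) counts. A toy model of the bookkeeping the assertions imply (assumed semantics, not the real API):

```cpp
#include <cstdint>
#include <set>
#include <string>
#include <utility>

int main() {
    // Explicit ts column: identity is (time, value) under one key.
    std::set<std::pair<uint64_t, std::string>> explicit_ts;
    explicit_ts.emplace(1, "test1");
    explicit_ts.emplace(1, "test2");  // same time, new value -> a new record
    explicit_ts.emplace(1, "test1");  // exact duplicate -> ignored
    // explicit_ts.size() == 2, matching GetIdxCnt() after the first round of Puts

    // DEFAULT_TS_COL_ID: identity is the value alone; the time keeps changing.
    std::set<std::string> default_ts;
    default_ts.insert("test1");  // "cur time" 100
    default_ts.insert("test1");  // "cur time" 200 -> still one record
    // default_ts.size() == 1, matching GetCount(&segment, DEFAULT_TS_COL_ID)
    return 0;
}
```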
diff --git a/src/storage/snapshot_test.cc b/src/storage/snapshot_test.cc
index 910a8bc7724..e9dd679eafc 100644
--- a/src/storage/snapshot_test.cc
+++ b/src/storage/snapshot_test.cc
@@ -1085,7 +1085,7 @@ TEST_F(SnapshotTest, MakeSnapshotAbsOrLat) {
SchemaCodec::SetColumnDesc(table_meta->add_column_desc(), "value", ::openmldb::type::kString);
SchemaCodec::SetIndex(table_meta->add_column_key(), "index1", "card|merchant", "", ::openmldb::type::kAbsOrLat, 0,
1);
- std::shared_ptr<Table> table = std::make_shared<MemTable>(*table_meta);
+ std::shared_ptr<MemTable> table = std::make_shared<MemTable>(*table_meta);
table->Init();
LogParts* log_part = new LogParts(12, 4, scmp);
@@ -1119,7 +1119,7 @@ TEST_F(SnapshotTest, MakeSnapshotAbsOrLat) {
google::protobuf::RepeatedPtrField<::openmldb::api::Dimension> d_list;
::openmldb::api::Dimension* d_ptr2 = d_list.Add();
d_ptr2->CopyFrom(dimensions);
- ASSERT_EQ(table->Put(i + 1, *result, d_list), true);
+ ASSERT_EQ(table->Put(i + 1, *result, d_list).ok(), true);
}
table->SchedGc();
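The snapshot test also shows the call-site half of the `bool` to `absl::Status` migration: success is now `st.ok()`, and failures carry a category and a message instead of a bare `false`. A minimal sketch of the new contract; the function body is a hypothetical stand-in, with only the status categories taken from this diff:

```cpp
#include <cstdint>
#include <string>

#include "absl/status/status.h"

// Hypothetical stand-in for the Table::Put contract introduced above.
absl::Status Put(int64_t time, const std::string& value, bool put_if_absent) {
    if (time < 0) return absl::InvalidArgumentError("negative ts");
    bool exists = false;  // imagine a ListContains-style probe here
    if (put_if_absent && exists) return absl::AlreadyExistsError("data exists");
    return absl::OkStatus();
}

bool CallSite() {
    return Put(1, "value", /*put_if_absent=*/false).ok();  // what the tests now assert
}
```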
diff --git a/src/storage/table.h b/src/storage/table.h
index 0766e4cf6c4..4c4a1f011f7 100644
--- a/src/storage/table.h
+++ b/src/storage/table.h
@@ -22,6 +22,7 @@
#include <string>
#include <vector>
+#include "absl/status/status.h"
#include "codec/codec.h"
#include "proto/tablet.pb.h"
#include "storage/iterator.h"
@@ -50,17 +51,16 @@ class Table {
int InitColumnDesc();
virtual bool Put(const std::string& pk, uint64_t time, const char* data, uint32_t size) = 0;
+ // DO NOT set a different default value in derived classes
+ virtual absl::Status Put(uint64_t time, const std::string& value, const Dimensions& dimensions,
+ bool put_if_absent = false) = 0;
- virtual bool Put(uint64_t time, const std::string& value, const Dimensions& dimensions) = 0;
-
- bool Put(const ::openmldb::api::LogEntry& entry) {
- return Put(entry.ts(), entry.value(), entry.dimensions());
- }
+ bool Put(const ::openmldb::api::LogEntry& entry) { return Put(entry.ts(), entry.value(), entry.dimensions()).ok(); }
virtual bool Delete(const ::openmldb::api::LogEntry& entry) = 0;
- virtual bool Delete(uint32_t idx, const std::string& key,
- const std::optional<uint64_t>& start_ts, const std::optional<uint64_t>& end_ts) = 0;
+ virtual bool Delete(uint32_t idx, const std::string& key, const std::optional<uint64_t>& start_ts,
+ const std::optional<uint64_t>& end_ts) = 0;
virtual TableIterator* NewIterator(const std::string& pk,
Ticket& ticket) = 0; // NOLINT
@@ -88,9 +88,7 @@ class Table {
}
return "";
}
- inline ::openmldb::common::StorageMode GetStorageMode() const {
- return storage_mode_;
- }
+ inline ::openmldb::common::StorageMode GetStorageMode() const { return storage_mode_; }
inline uint32_t GetId() const { return id_; }
inline uint32_t GetIdxCnt() const { return table_index_.Size(); }
@@ -173,7 +171,7 @@ class Table {
virtual uint64_t GetRecordByteSize() const = 0;
virtual uint64_t GetRecordIdxByteSize() = 0;
- virtual int GetCount(uint32_t index, const std::string& pk, uint64_t& count) = 0; // NOLINT
+ virtual int GetCount(uint32_t index, const std::string& pk, uint64_t& count) = 0; // NOLINT
protected:
void UpdateTTL();
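The `DO NOT set a different default value` warning on the new virtual `Put` is worth spelling out: C++ resolves default arguments from the static type of the expression, not the dynamic type, so an override with a different default silently changes behavior depending on which pointer type the caller holds. A self-contained illustration with toy types, not OpenMLDB's:

```cpp
#include <iostream>

struct Base {
    virtual ~Base() = default;
    virtual void Put(bool put_if_absent = false) { std::cout << "base " << put_if_absent << "\n"; }
};

struct Derived : Base {
    // The override's body is always used, but NOT its default argument.
    void Put(bool put_if_absent = true) override { std::cout << "derived " << put_if_absent << "\n"; }
};

int main() {
    Derived d;
    Base* b = &d;
    b->Put();  // prints "derived 0": Derived's body with Base's default (false)
    d.Put();   // prints "derived 1": same body with Derived's default (true)
    return 0;
}
```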
diff --git a/src/storage/table_iterator_test.cc b/src/storage/table_iterator_test.cc
index 7ba932422e1..3af20940266 100644
--- a/src/storage/table_iterator_test.cc
+++ b/src/storage/table_iterator_test.cc
@@ -450,7 +450,7 @@ TEST_P(TableIteratorTest, SeekNonExistent) {
ASSERT_EQ(0, now - wit->GetKey());
}
-INSTANTIATE_TEST_CASE_P(TestMemAndHDD, TableIteratorTest,
+INSTANTIATE_TEST_SUITE_P(TestMemAndHDD, TableIteratorTest,
::testing::Values(::openmldb::common::kMemory, ::openmldb::common::kHDD));
} // namespace storage
diff --git a/src/storage/table_test.cc b/src/storage/table_test.cc
index 251e92986c6..43b3508822e 100644
--- a/src/storage/table_test.cc
+++ b/src/storage/table_test.cc
@@ -198,7 +198,7 @@ TEST_P(TableTest, MultiDimissionPut0) {
::openmldb::codec::SDKCodec sdk_codec(meta);
std::string result;
sdk_codec.EncodeRow({"d0", "d1", "d2"}, &result);
- bool ok = table->Put(1, result, dimensions);
+ bool ok = table->Put(1, result, dimensions).ok();
ASSERT_TRUE(ok);
// some functions in disk table need to be implemented.
// refer to issue #1238
@@ -808,7 +808,7 @@ TEST_P(TableTest, TableIteratorTS) {
dim->set_key(row[1]);
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- table->Put(0, value, request.dimensions());
+ ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
}
TableIterator* it = table->NewTraverseIterator(0);
it->SeekToFirst();
@@ -921,7 +921,7 @@ TEST_P(TableTest, TraverseIteratorCount) {
dim->set_key(row[1]);
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- table->Put(0, value, request.dimensions());
+ ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
}
TableIterator* it = table->NewTraverseIterator(0);
it->SeekToFirst();
@@ -1048,7 +1048,7 @@ TEST_P(TableTest, AbsAndLatSetGet) {
dim->set_key("mcc");
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- table->Put(0, value, request.dimensions());
+ ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
}
// test get and set ttl
ASSERT_EQ(10, (int64_t)table->GetIndex(0)->GetTTL()->abs_ttl / (10 * 6000));
@@ -1149,7 +1149,7 @@ TEST_P(TableTest, AbsOrLatSetGet) {
dim->set_key("mcc");
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- table->Put(0, value, request.dimensions());
+ ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
}
// test get and set ttl
ASSERT_EQ(10, (int64_t)table->GetIndex(0)->GetTTL()->abs_ttl / (10 * 6000));
@@ -1562,7 +1562,7 @@ TEST_P(TableTest, TraverseIteratorCountWithLimit) {
dim->set_key(row[1]);
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- table->Put(0, value, request.dimensions());
+ ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
}
TableIterator* it = table->NewTraverseIterator(0);
@@ -1669,7 +1669,7 @@ TEST_P(TableTest, TSColIDLength) {
dim1->set_key(row[0]);
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- table->Put(0, value, request.dimensions());
+ ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
}
TableIterator* it = table->NewTraverseIterator(0);
@@ -1727,7 +1727,7 @@ TEST_P(TableTest, MultiDimensionPutTS) {
dim->set_key(row[1]);
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- table->Put(0, value, request.dimensions());
+ ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
}
TableIterator* it = table->NewTraverseIterator(0);
it->SeekToFirst();
@@ -1781,7 +1781,7 @@ TEST_P(TableTest, MultiDimensionPutTS1) {
dim->set_key(row[1]);
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- table->Put(0, value, request.dimensions());
+ ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
}
TableIterator* it = table->NewTraverseIterator(0);
it->SeekToFirst();
@@ -1823,7 +1823,7 @@ TEST_P(TableTest, MultiDimissionPutTS2) {
::openmldb::codec::SDKCodec sdk_codec(meta);
std::string result;
sdk_codec.EncodeRow({"d0", "d1", "d2"}, &result);
- bool ok = table->Put(100, result, dimensions);
+ bool ok = table->Put(100, result, dimensions).ok();
ASSERT_TRUE(ok);
TableIterator* it = table->NewTraverseIterator(0);
@@ -1885,7 +1885,7 @@ TEST_P(TableTest, AbsAndLat) {
}
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- table->Put(0, value, request.dimensions());
+ ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
}
for (int i = 0; i <= 5; i++) {
@@ -1938,10 +1938,11 @@ TEST_P(TableTest, NegativeTs) {
dim->set_key(row[0]);
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- ASSERT_FALSE(table->Put(0, value, request.dimensions()));
+ auto st = table->Put(0, value, request.dimensions());
+ ASSERT_TRUE(absl::IsInvalidArgument(st)) << st.ToString();
}
-INSTANTIATE_TEST_CASE_P(TestMemAndHDD, TableTest,
+INSTANTIATE_TEST_SUITE_P(TestMemAndHDD, TableTest,
::testing::Values(::openmldb::common::kMemory, ::openmldb::common::kHDD));
} // namespace storage
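The NegativeTs change above is the test-side idiom for status-returning APIs: assert on the status category rather than string-matching the message, and stream `st.ToString()` into the failure output so a wrong category is easy to diagnose. In isolation:

```cpp
#include "absl/status/status.h"
#include "gtest/gtest.h"

TEST(StatusIdiom, InvalidArgument) {
    // The message text is an assumption; only the category matters to the assertion.
    absl::Status st = absl::InvalidArgumentError("ts is negative");
    ASSERT_TRUE(absl::IsInvalidArgument(st)) << st.ToString();
}
```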
diff --git a/src/tablet/tablet_impl.cc b/src/tablet/tablet_impl.cc
index 8691cf2f90b..8b2b446e874 100644
--- a/src/tablet/tablet_impl.cc
+++ b/src/tablet/tablet_impl.cc
@@ -767,7 +767,8 @@ void TabletImpl::Put(RpcController* controller, const ::openmldb::api::PutReques
if (request->ts_dimensions_size() > 0) {
entry.mutable_ts_dimensions()->CopyFrom(request->ts_dimensions());
}
- bool ok = false;
+
+ absl::Status st;
if (request->dimensions_size() > 0) {
int32_t ret_code = CheckDimessionPut(request, table->GetIdxCnt());
if (ret_code != 0) {
@@ -776,16 +777,27 @@ void TabletImpl::Put(RpcController* controller, const ::openmldb::api::PutReques
return;
}
DLOG(INFO) << "put data to tid " << tid << " pid " << pid << " with key " << request->dimensions(0).key();
- ok = table->Put(entry.ts(), entry.value(), entry.dimensions());
+ // 1. normal put: ok, or invalid data
+ // 2. put if absent: ok, exists but ignored, or invalid data
+ st = table->Put(entry.ts(), entry.value(), entry.dimensions(), request->put_if_absent());
}
- if (!ok) {
+
+ if (!st.ok()) {
+ if (request->put_if_absent() && absl::IsAlreadyExists(st)) {
+ // not a failure, but we shouldn't write a log entry
+ response->set_code(::openmldb::base::ReturnCode::kOk);
+ response->set_msg("exists but ignore");
+ return;
+ }
+ LOG(WARNING) << st.ToString();
response->set_code(::openmldb::base::ReturnCode::kPutFailed);
- response->set_msg("put failed");
+ response->set_msg(st.ToString());
return;
}
response->set_code(::openmldb::base::ReturnCode::kOk);
std::shared_ptr<LogReplicator> replicator;
+ bool ok = false;
do {
replicator = GetReplicator(request->tid(), request->pid());
if (!replicator) {
@@ -2220,9 +2232,8 @@ void TabletImpl::AppendEntries(RpcController* controller, const ::openmldb::api:
return;
}
if (entry.has_method_type() && entry.method_type() == ::openmldb::api::MethodType::kDelete) {
- table->Delete(entry);
- }
- if (!table->Put(entry)) {
+ table->Delete(entry); // TODO(hw): handle errors
+ } else if (!table->Put(entry)) { // put if type is not delete
PDLOG(WARNING, "fail to put entry. tid %u pid %u", tid, pid);
response->set_code(::openmldb::base::ReturnCode::kFailToAppendEntriesToReplicator);
response->set_msg("fail to append entry to table");
diff --git a/src/tablet/tablet_impl_func_test.cc b/src/tablet/tablet_impl_func_test.cc
index c84729f288d..c07084a396d 100644
--- a/src/tablet/tablet_impl_func_test.cc
+++ b/src/tablet/tablet_impl_func_test.cc
@@ -89,7 +89,7 @@ void CreateBaseTable(::openmldb::storage::Table*& table, // NOLINT
dim->set_key(row[1]);
std::string value;
ASSERT_EQ(0, codec.EncodeRow(row, &value));
- ASSERT_TRUE(table->Put(0, value, request.dimensions()));
+ ASSERT_TRUE(table->Put(0, value, request.dimensions()).ok());
}
return;
}
@@ -389,7 +389,7 @@ TEST_P(TabletFuncTest, GetTimeIndex_ts1_iterator) {
RunGetTimeIndexAssert(&query_its, base_ts, base_ts - 100);
}
-INSTANTIATE_TEST_CASE_P(TabletMemAndHDD, TabletFuncTest,
+INSTANTIATE_TEST_SUITE_P(TabletMemAndHDD, TabletFuncTest,
::testing::Values(::openmldb::common::kMemory, ::openmldb::common::kHDD,
::openmldb::common::kSSD));