diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SelectIntoPlan.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SelectIntoPlan.scala index 4824e28d0d3..7dcdd51575b 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SelectIntoPlan.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SelectIntoPlan.scala @@ -16,7 +16,7 @@ package com._4paradigm.openmldb.batch.nodes import com._4paradigm.hybridse.vm.PhysicalSelectIntoNode -import com._4paradigm.openmldb.batch.utils.HybridseUtil +import com._4paradigm.openmldb.batch.utils.{HybridseUtil, OpenmldbTableUtil} import com._4paradigm.openmldb.batch.{PlanContext, SparkInstance} import org.slf4j.LoggerFactory @@ -45,6 +45,25 @@ object SelectIntoPlan { val dbt = HybridseUtil.hiveDest(outPath) logger.info(s"offline select into: hive way, write mode[${mode}], out table ${dbt}") input.getDf().write.format("hive").mode(mode).saveAsTable(dbt) + } else if (format == "openmldb") { + + val (db, table) = HybridseUtil.getOpenmldbDbAndTable(outPath) + + val createIfNotExists = extra.get("create_if_not_exists").get.toBoolean + if (createIfNotExists) { + logger.info("Try to create openmldb output table: " + table) + + OpenmldbTableUtil.createOpenmldbTableFromDf(ctx.getOpenmldbSession, input.getDf(), db, table) + } + + val writeOptions = Map( + "db" -> db, + "table" -> table, + "zkCluster" -> ctx.getConf.openmldbZkCluster, + "zkPath" -> ctx.getConf.openmldbZkRootPath) + + input.getDf().write.options(writeOptions).format("openmldb").mode(mode).save() + } else { logger.info("offline select into: format[{}], options[{}], write mode[{}], out path {}", format, options, mode, outPath) diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala index bdb9f30a727..f4c17dd1f09 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/HybridseUtil.scala @@ -213,6 +213,8 @@ object HybridseUtil { // load data: read format, select into: write format val format = if (file.toLowerCase().startsWith("hive://")) { "hive" + } else if (file.toLowerCase().startsWith("openmldb://")) { + "openmldb" } else { parseOption(getOptionFromNode(node, "format"), "csv", getStringOrDefault).toLowerCase } @@ -252,7 +254,11 @@ object HybridseUtil { // only for select into, "" means N/A extraOptions += ("coalesce" -> parseOption(getOptionFromNode(node, "coalesce"), "0", getIntOrDefault)) extraOptions += ("sql" -> parseOption(getOptionFromNode(node, "sql"), "", getStringOrDefault)) - extraOptions += ("writer_type") -> parseOption(getOptionFromNode(node, "writer_type"), "single", getStringOrDefault) + extraOptions += ("writer_type") -> parseOption(getOptionFromNode(node, "writer_type"), "single", + getStringOrDefault) + + extraOptions += ("create_if_not_exists" -> parseOption(getOptionFromNode(node, "create_if_not_exists"), + "true", getBoolOrDefault)) (format, options.toMap, mode, extraOptions.toMap) } @@ -451,6 +457,19 @@ object HybridseUtil { path.substring(tableStartPos) } + def getOpenmldbDbAndTable(path: String): (String, String) = { + require(path.toLowerCase.startsWith("openmldb://")) + // openmldb:// + val tableStartPos = 11 + val dbAndTableString = path.substring(tableStartPos) + + require(dbAndTableString.split("\\.").size == 2) + + val db = dbAndTableString.split("\\.")(0) + val table = dbAndTableString.split("\\.")(1) + (db, table) + } + private def hiveLoad(openmldbSession: OpenmldbSession, file: String, columns: util.List[Common.ColumnDesc], loadDataSql: String = ""): DataFrame = { if (logger.isDebugEnabled()) { diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/OpenmldbTableUtil.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/OpenmldbTableUtil.scala index bb1727d751c..0364c2fbce5 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/OpenmldbTableUtil.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/OpenmldbTableUtil.scala @@ -36,7 +36,7 @@ object OpenmldbTableUtil { val schema = df.schema - var createTableSql = s"CREATE TABLE $tableName (" + var createTableSql = s"CREATE TABLE IF NOT EXISTS $tableName (" schema.map(structField => { val colName = structField.name val colType = DataTypeUtil.sparkTypeToString(structField.dataType)