diff --git a/feathr_project/feathr/client.py b/feathr_project/feathr/client.py
index a60205b25..4c7a97e6a 100644
--- a/feathr_project/feathr/client.py
+++ b/feathr_project/feathr/client.py
@@ -236,9 +236,9 @@ def get_snowflake_path(self, database: str, schema: str, dbtable: str = None, qu
         if dbtable is None and query is None:
             raise RuntimeError("One of dbtable or query must be specified..")
         if dbtable:
-            return f"snowflake://snowflake_account/?sfDatabase={database}&fSchema={schema}&dbtable={dbtable}"
+            return f"snowflake://snowflake_account/?sfDatabase={database}&sfSchema={schema}&dbtable={dbtable}"
         else:
-            return f"snowflake://snowflake_account/?sfDatabase={database}&fSchema={schema}&query={query}"
+            return f"snowflake://snowflake_account/?sfDatabase={database}&sfSchema={schema}&query={query}"

     def list_registered_features(self, project_name: str = None) -> List[str]:
         """List all the already registered features under the given project.
diff --git a/feathr_project/feathr/registry/_feathr_registry_client_hack.py b/feathr_project/feathr/registry/_feathr_registry_client_aws.py
similarity index 91%
rename from feathr_project/feathr/registry/_feathr_registry_client_hack.py
rename to feathr_project/feathr/registry/_feathr_registry_client_aws.py
index bfd28305b..538210dc5 100644
--- a/feathr_project/feathr/registry/_feathr_registry_client_hack.py
+++ b/feathr_project/feathr/registry/_feathr_registry_client_aws.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict
 import json
 import logging
 import requests
@@ -12,7 +12,7 @@ def check(r):
     return r


-class _FeatureRegistryHack(_FeatureRegistry):
+class _FeatureRegistryAWS(_FeatureRegistry):
     def __init__(self, project_name: str, endpoint: str, project_tags: Dict[str, str] = None, credential=None, config_path=None):
         self.project_name = project_name
diff --git a/feathr_project/feathr/registry/feature_registry.py b/feathr_project/feathr/registry/feature_registry.py
index 97fbb2bcf..1d0bf9596 100644
--- a/feathr_project/feathr/registry/feature_registry.py
+++ b/feathr_project/feathr/registry/feature_registry.py
@@ -55,13 +55,13 @@ def save_to_feature_config_from_context(self, anchor_list, derived_feature_list,
 def default_registry_client(project_name: str, config_path:str = "./feathr_config.yaml", project_registry_tag: Dict[str, str]=None, credential = None) -> FeathrRegistry:
     from feathr.registry._feathr_registry_client import _FeatureRegistry
     from feathr.registry._feature_registry_purview import _PurviewRegistry
-    from feathr.registry._feathr_registry_client_hack import _FeatureRegistryHack
+    from feathr.registry._feathr_registry_client_aws import _FeatureRegistryAWS
     from aws_requests_auth.aws_auth import AWSRequestsAuth
     envutils = _EnvVaraibleUtil(config_path)
     registry_endpoint = envutils.get_environment_variable_with_default("feature_registry", "api_endpoint")
     if registry_endpoint and isinstance(credential, AWSRequestsAuth):
-        return _FeatureRegistryHack(project_name, endpoint=registry_endpoint, project_tags=project_registry_tag, credential=credential)
+        return _FeatureRegistryAWS(project_name, endpoint=registry_endpoint, project_tags=project_registry_tag, credential=credential)
     elif registry_endpoint:
         return _FeatureRegistry(project_name, endpoint=registry_endpoint, project_tags=project_registry_tag, credential=credential)
     else:
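The client.py hunk above corrects `fSchema` to `sfSchema`, the parameter name the Snowflake Spark connector expects. A minimal usage sketch of the fixed method follows; the client construction and all identifier values are illustrative placeholders, not taken from this diff:

```python
from feathr import FeathrClient

# Illustrative setup; config_path and the Snowflake names below are placeholders.
client = FeathrClient(config_path="./feathr_config.yaml")

# Table form; exactly one of dbtable/query must be supplied, else RuntimeError.
table_path = client.get_snowflake_path(
    database="MY_DB", schema="MY_SCHEMA", dbtable="MY_TABLE")
# -> snowflake://snowflake_account/?sfDatabase=MY_DB&sfSchema=MY_SCHEMA&dbtable=MY_TABLE

# Query form.
query_path = client.get_snowflake_path(
    database="MY_DB", schema="MY_SCHEMA", query="SELECT * FROM MY_TABLE")
# -> snowflake://snowflake_account/?sfDatabase=MY_DB&sfSchema=MY_SCHEMA&query=SELECT * FROM MY_TABLE
```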
diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala
index 37bece6b8..83bb093a3 100644
--- a/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala
+++ b/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala
@@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
   new JsonSubTypes.Type(value = classOf[PathList], name = "pathlist"),
   new JsonSubTypes.Type(value = classOf[Jdbc], name = "jdbc"),
   new JsonSubTypes.Type(value = classOf[GenericLocation], name = "generic"),
+  new JsonSubTypes.Type(value = classOf[Snowflake], name = "snowflake"),
 ))
 trait DataLocation {
   /**
diff --git a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/jdbc/JdbcConnectorChooser.scala b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/jdbc/JdbcConnectorChooser.scala
index d96648122..f35117844 100644
--- a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/jdbc/JdbcConnectorChooser.scala
+++ b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/jdbc/JdbcConnectorChooser.scala
@@ -15,13 +15,11 @@ sealed trait JdbcConnectorChooser
 object JdbcConnectorChooser {
   case object SqlServer extends JdbcConnectorChooser
   case object Postgres extends JdbcConnectorChooser
-  case object SnowflakeSql extends JdbcConnectorChooser
   case object DefaultJDBC extends JdbcConnectorChooser

   def getType (url: String): JdbcConnectorChooser = url match {
     case url if url.startsWith("jdbc:sqlserver") => SqlServer
     case url if url.startsWith("jdbc:postgresql:") => Postgres
-    case url if url.startsWith("jdbc:snowflake:") => SnowflakeSql
     case _ => DefaultJDBC
   }

@@ -29,7 +27,6 @@ object JdbcConnectorChooser {
     val sqlDbType = getType(url)
     val dataLoader = sqlDbType match {
       case SqlServer => new SqlServerDataLoader(ss)
-      case SnowflakeSql => new SnowflakeSqlDataLoader(ss)
       case _ => new SqlServerDataLoader(ss) //default jdbc data loader place holder
     }
     dataLoader
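The `DataLocation` hunk above registers a `snowflake` discriminator with Jackson, so a location block typed `snowflake` now deserializes into the new `Snowflake` location class (defined elsewhere; its definition is not part of this diff). A hedged sketch of such a payload: only the `type` value is confirmed by the diff, and the remaining keys are assumptions modeled on the sfDatabase/sfSchema/dbtable query parameters emitted by `get_snowflake_path`:

```python
import json

# Hypothetical "snowflake" location config. Only the "type" discriminator is
# grounded in this diff; "database", "schema", and "dbtable" are assumed field
# names mirroring the URL parameters built in client.py.
location = {
    "type": "snowflake",
    "database": "MY_DB",
    "schema": "MY_SCHEMA",
    "dbtable": "MY_TABLE",
}
print(json.dumps(location, indent=2))
```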
"sfRole" -> ss.conf.get("sfRole"), - ) - } - - def getJdbcAuth(ss: SparkSession): Map[String, String] = { - // If user set password, then we use password to auth - ss.conf.getOption("sfPassword") match { - case Some(_) => - Map[String, String]( - "sfUser" -> ss.conf.get("sfUser"), - "sfRole" -> ss.conf.get("sfRole"), - "sfPassword" -> ss.conf.get("sfPassword"), - ) - case _ => { - // TODO Add token support - Map[String, String]() - } - } - } - - override def loadDataFrame(url: String, jdbcOptions: Map[String, String] = Map[String, String]()): DataFrame = { - val sparkReader = getDFReader(jdbcOptions, url) - sparkReader - .option("url", url) - .load() - } -} \ No newline at end of file diff --git a/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala b/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala index e9d3a2bf1..1604e174c 100644 --- a/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala +++ b/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala @@ -15,7 +15,7 @@ import com.linkedin.feathr.offline.source.SourceFormatType import com.linkedin.feathr.offline.source.SourceFormatType.SourceFormatType import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler import com.linkedin.feathr.offline.source.dataloader.hdfs.FileFormat -import com.linkedin.feathr.offline.source.dataloader.jdbc.JdbcUtils +import com.linkedin.feathr.offline.source.dataloader.jdbc.{JdbcUtils, SnowflakeUtils} import com.linkedin.feathr.offline.source.pathutil.{PathChecker, TimeBasedHdfsPathAnalyzer, TimeBasedHdfsPathGenerator} import com.linkedin.feathr.offline.util.AclCheckUtils.getLatestPath import com.linkedin.feathr.offline.util.datetime.OfflineDateTimeUtils @@ -648,6 +648,9 @@ private[offline] object SourceUtils { case FileFormat.JDBC => { JdbcUtils.loadDataFrame(ss, inputData.inputPath) } + case FileFormat.SNOWFLAKE => { + SnowflakeUtils.loadDataFrame(ss, inputData.inputPath) + } case FileFormat.CSV => { ss.read.format("csv").option("header", "true").option("delimiter", csvDelimiterOption).load(inputData.inputPath) }