Skip to content

Commit

Permalink
update get_snowflake_path
Browse files Browse the repository at this point in the history
  • Loading branch information
aabbasi-hbo committed Nov 4, 2022
1 parent 6fa8596 commit b71b415
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 77 deletions.
4 changes: 2 additions & 2 deletions feathr_project/feathr/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,9 @@ def get_snowflake_path(self, database: str, schema: str, dbtable: str = None, qu
if dbtable is None and query is None:
raise RuntimeError("One of dbtable or query must be specified..")
if dbtable:
return f"snowflake://snowflake_account/?sfDatabase={database}&fSchema={schema}&dbtable={dbtable}"
return f"snowflake://snowflake_account/?sfDatabase={database}&sfSchema={schema}&dbtable={dbtable}"
else:
return f"snowflake://snowflake_account/?sfDatabase={database}&fSchema={schema}&query={query}"
return f"snowflake://snowflake_account/?sfDatabase={database}&sfSchema={schema}&query={query}"

def list_registered_features(self, project_name: str = None) -> List[str]:
"""List all the already registered features under the given project.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict
import json
import logging
import requests
Expand All @@ -12,7 +12,7 @@ def check(r):
return r


class _FeatureRegistryHack(_FeatureRegistry):
class _FeatureRegistryAWS(_FeatureRegistry):
def __init__(self, project_name: str, endpoint: str, project_tags: Dict[str, str] = None, credential=None,
config_path=None):
self.project_name = project_name
Expand Down
4 changes: 2 additions & 2 deletions feathr_project/feathr/registry/feature_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ def save_to_feature_config_from_context(self, anchor_list, derived_feature_list,
def default_registry_client(project_name: str, config_path:str = "./feathr_config.yaml", project_registry_tag: Dict[str, str]=None, credential = None) -> FeathrRegistry:
from feathr.registry._feathr_registry_client import _FeatureRegistry
from feathr.registry._feature_registry_purview import _PurviewRegistry
from feathr.registry._feathr_registry_client_hack import _FeatureRegistryHack
from feathr.registry._feathr_registry_client_aws import _FeatureRegistryAWS
from aws_requests_auth.aws_auth import AWSRequestsAuth

envutils = _EnvVaraibleUtil(config_path)
registry_endpoint = envutils.get_environment_variable_with_default("feature_registry", "api_endpoint")
if registry_endpoint and isinstance(credential, AWSRequestsAuth):
return _FeatureRegistryHack(project_name, endpoint=registry_endpoint, project_tags=project_registry_tag, credential=credential)
return _FeatureRegistryAWS(project_name, endpoint=registry_endpoint, project_tags=project_registry_tag, credential=credential)
elif registry_endpoint:
return _FeatureRegistry(project_name, endpoint=registry_endpoint, project_tags=project_registry_tag, credential=credential)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
new JsonSubTypes.Type(value = classOf[PathList], name = "pathlist"),
new JsonSubTypes.Type(value = classOf[Jdbc], name = "jdbc"),
new JsonSubTypes.Type(value = classOf[GenericLocation], name = "generic"),
new JsonSubTypes.Type(value = classOf[Snowflake], name = "snowflake"),
))
trait DataLocation {
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,18 @@ sealed trait JdbcConnectorChooser
object JdbcConnectorChooser {
case object SqlServer extends JdbcConnectorChooser
case object Postgres extends JdbcConnectorChooser
case object SnowflakeSql extends JdbcConnectorChooser
case object DefaultJDBC extends JdbcConnectorChooser

/**
 * Resolve which connector type a JDBC connection URL maps to, based on its scheme prefix.
 * URLs with an unrecognized prefix fall back to DefaultJDBC.
 */
def getType (url: String): JdbcConnectorChooser = {
  if (url.startsWith("jdbc:sqlserver")) SqlServer
  else if (url.startsWith("jdbc:postgresql:")) Postgres
  else if (url.startsWith("jdbc:snowflake:")) SnowflakeSql
  else DefaultJDBC
}

def getJdbcConnector(ss: SparkSession, url: String): JdbcConnector = {
val sqlDbType = getType(url)
val dataLoader = sqlDbType match {
case SqlServer => new SqlServerDataLoader(ss)
case SnowflakeSql => new SnowflakeSqlDataLoader(ss)
case _ => new SqlServerDataLoader(ss) //default jdbc data loader place holder
}
dataLoader
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import com.linkedin.feathr.offline.source.SourceFormatType
import com.linkedin.feathr.offline.source.SourceFormatType.SourceFormatType
import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler
import com.linkedin.feathr.offline.source.dataloader.hdfs.FileFormat
import com.linkedin.feathr.offline.source.dataloader.jdbc.JdbcUtils
import com.linkedin.feathr.offline.source.dataloader.jdbc.{JdbcUtils, SnowflakeUtils}
import com.linkedin.feathr.offline.source.pathutil.{PathChecker, TimeBasedHdfsPathAnalyzer, TimeBasedHdfsPathGenerator}
import com.linkedin.feathr.offline.util.AclCheckUtils.getLatestPath
import com.linkedin.feathr.offline.util.datetime.OfflineDateTimeUtils
Expand Down Expand Up @@ -648,6 +648,9 @@ private[offline] object SourceUtils {
case FileFormat.JDBC => {
JdbcUtils.loadDataFrame(ss, inputData.inputPath)
}
case FileFormat.SNOWFLAKE => {
SnowflakeUtils.loadDataFrame(ss, inputData.inputPath)
}
case FileFormat.CSV => {
ss.read.format("csv").option("header", "true").option("delimiter", csvDelimiterOption).load(inputData.inputPath)
}
Expand Down

0 comments on commit b71b415

Please sign in to comment.