diff --git a/.changes/unreleased/Features-20231215-191154.yaml b/.changes/unreleased/Features-20231215-191154.yaml
new file mode 100644
index 000000000..1cb8020cc
--- /dev/null
+++ b/.changes/unreleased/Features-20231215-191154.yaml
@@ -0,0 +1,6 @@
+kind: Features
+body: Support limiting get_catalog by object name
+time: 2023-12-15T19:11:54.536441-05:00
+custom:
+  Author: mikealfare
+  Issue: "900"
diff --git a/dbt/adapters/spark/cache.py b/dbt/adapters/spark/cache.py
new file mode 100644
index 000000000..3d2f721e5
--- /dev/null
+++ b/dbt/adapters/spark/cache.py
@@ -0,0 +1,29 @@
+from dbt.adapters.base import BaseRelation
+from dbt.adapters.cache import RelationsCache
+from dbt.exceptions import MissingRelationError
+from dbt.utils import lowercase
+
+
+class SparkRelationsCache(RelationsCache):
+    def get_relation_from_stub(self, relation_stub: BaseRelation) -> BaseRelation:
+        """
+        Case-insensitively find the cached relation matching the given stub.
+
+        :param BaseRelation relation_stub: The relation to look for
+        :return BaseRelation: The cached version of the relation
+        """
+        with self.lock:
+            results = [
+                relation.inner
+                for relation in self.relations.values()
+                if all(
+                    {
+                        lowercase(relation.database) == lowercase(relation_stub.database),
+                        lowercase(relation.schema) == lowercase(relation_stub.schema),
+                        lowercase(relation.identifier) == lowercase(relation_stub.identifier),
+                    }
+                )
+            ]
+        if len(results) == 0:
+            raise MissingRelationError(relation_stub)
+        return results[0]
diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index 16c3a3cb7..26fc0b51d 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -4,10 +4,10 @@
 from typing import Any, Dict, Iterable, List, Optional, Union, Type, Tuple, Callable, Set
 
 from dbt.adapters.base.relation import InformationSchema
+from dbt.adapters.capability import CapabilityDict, CapabilitySupport, Support, Capability
 from dbt.contracts.graph.manifest import Manifest
 from typing_extensions import TypeAlias
 
-
 import agate
 
 import dbt
@@ -19,6 +19,7 @@
 from dbt.adapters.spark import SparkConnectionManager
 from dbt.adapters.spark import SparkRelation
 from dbt.adapters.spark import SparkColumn
+from dbt.adapters.spark.cache import SparkRelationsCache
 from dbt.adapters.spark.python_submissions import (
     JobClusterPythonJobHelper,
     AllPurposeClusterPythonJobHelper,
@@ -101,12 +102,21 @@ class SparkAdapter(SQLAdapter):
         ConstraintType.foreign_key: ConstraintSupport.NOT_ENFORCED,
     }
 
+    _capabilities = CapabilityDict(
+        {Capability.SchemaMetadataByRelations: CapabilitySupport(support=Support.Full)}
+    )
+
     Relation: TypeAlias = SparkRelation
     RelationInfo = Tuple[str, str, str]
     Column: TypeAlias = SparkColumn
     ConnectionManager: TypeAlias = SparkConnectionManager
     AdapterSpecificConfigs: TypeAlias = SparkConfig
 
+    def __init__(self, config) -> None:  # type: ignore
+        super().__init__(config)
+        # Replace the base cache with one that supports lookup by relation stub
+        self.cache: SparkRelationsCache = SparkRelationsCache()
+
     @classmethod
     def date_function(cls) -> str:
         return "current_timestamp()"
@@ -377,6 +387,25 @@ def get_catalog(
         catalogs, exceptions = catch_as_completed(futures)
         return catalogs, exceptions
 
+    def get_catalog_by_relations(
+        self, manifest: Manifest, relations: Set[BaseRelation]
+    ) -> Tuple[agate.Table, List[Exception]]:
+        with executor(self.config) as tpe:
+            futures: List[Future[agate.Table]] = []
+            for relation in relations:
+                futures.append(
+                    tpe.submit_connected(
+                        self,
+                        str(relation),
+                        self._get_one_catalog_by_relations,
+                        relation.information_schema_only(),
+                        [relation],
+                        manifest,
+                    )
+                )
+            catalogs, exceptions = catch_as_completed(futures)
+            return catalogs, exceptions
+
     def _get_one_catalog(
         self,
         information_schema: InformationSchema,
@@ -390,13 +419,28 @@ def _get_one_catalog(
         database = information_schema.database
         schema = list(schemas)[0]
 
+        relations = self.list_relations(database, schema)
+        return self._get_relation_metadata_at_column_level(relations)
+
+    def _get_relation_metadata_at_column_level(self, relations: List[BaseRelation]) -> agate.Table:
         columns: List[Dict[str, Any]] = []
-        for relation in self.list_relations(database, schema):
+        for relation in relations:
             logger.debug("Getting table schema for relation {}", str(relation))
             columns.extend(self._get_columns_for_catalog(relation))
         return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER)
 
+    def _get_one_catalog_by_relations(
+        self,
+        information_schema: InformationSchema,
+        relations: List[BaseRelation],
+        manifest: Manifest,
+    ) -> agate.Table:
+        # Resolve each stub to its cached relation; a miss raises MissingRelationError
+        cached_relations = [
+            self.cache.get_relation_from_stub(relation_stub) for relation_stub in relations
+        ]
+        return self._get_relation_metadata_at_column_level(cached_relations)
+
     def check_schema_exists(self, database: str, schema: str) -> bool:
         results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database})
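
For reference, a minimal sketch of how the two new entry points fit together. This is illustrative only: `adapter` and `manifest` stand in for objects dbt constructs at runtime, the schema/identifier values are made up, and `get_relation_from_stub` assumes the cache has already been populated (e.g. by `list_relations` during a run).

    from dbt.adapters.spark import SparkRelation

    # A "stub" relation carries only the naming parts; values are hypothetical.
    stub = SparkRelation.create(schema="analytics", identifier="orders")

    # Case-insensitive lookup against the populated cache; raises
    # MissingRelationError when no cached relation matches the stub.
    cached = adapter.cache.get_relation_from_stub(stub)

    # Fans out one future per relation and returns (catalog_table, exceptions),
    # mirroring get_catalog but limited to the named objects.
    catalog, exceptions = adapter.get_catalog_by_relations(manifest, {stub})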