diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index b4ef2180d71d45..7d63f41f4bcf03 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -5,6 +5,7 @@ import os import os.path import platform +import re from dataclasses import dataclass from typing import Dict, Iterable, List, Optional, Union @@ -33,6 +34,7 @@ from datahub.ingestion.source.snowflake.constants import ( GENERIC_PERMISSION_ERROR_KEY, SnowflakeEdition, + SnowflakeObjectDomain, ) from datahub.ingestion.source.snowflake.snowflake_assertion import ( SnowflakeAssertionsHandler, @@ -162,6 +164,8 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): self.data_dictionary = SnowflakeDataDictionary(connection=self.connection) self.lineage_extractor: Optional[SnowflakeLineageExtractor] = None + self.discovered_datasets: Optional[List[str]] = None + self.aggregator: SqlParsingAggregator = self._exit_stack.enter_context( SqlParsingAggregator( platform=self.identifiers.platform, @@ -182,6 +186,8 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): generate_usage_statistics=False, generate_operations=False, format_queries=self.config.format_sql_queries, + is_temp_table=self._is_temp_table, + is_allowed_table=self._is_allowed_table, ) ) self.report.sql_aggregator = self.aggregator.report @@ -444,6 +450,34 @@ class SnowflakePrivilege: return _report + def _is_temp_table(self, name: str) -> bool: + if any( + re.match(pattern, name, flags=re.IGNORECASE) + for pattern in self.config.temporary_tables_pattern + ): + return True + + # This is also a temp table if + # 1. this name would be allowed by the dataset patterns, and + # 2. we have a list of discovered tables, and + # 3. 
it's not in the discovered tables list + if ( + self.filters.is_dataset_pattern_allowed(name, SnowflakeObjectDomain.TABLE) + and self.discovered_datasets + and name not in self.discovered_datasets + ): + return True + + return False + + def _is_allowed_table(self, name: str) -> bool: + if self.discovered_datasets and name not in self.discovered_datasets: + return False + + return self.filters.is_dataset_pattern_allowed( + name, SnowflakeObjectDomain.TABLE + ) + def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: return [ *super().get_workunit_processors(), @@ -513,7 +547,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: ) return - discovered_datasets = discovered_tables + discovered_views + self.discovered_datasets = discovered_tables + discovered_views if self.config.use_queries_v2: with self.report.new_stage(f"*: {VIEW_PARSING}"): @@ -538,13 +572,14 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: filters=self.filters, identifiers=self.identifiers, schema_resolver=schema_resolver, - discovered_tables=discovered_datasets, + discovered_tables=self.discovered_datasets, graph=self.ctx.graph, ) # TODO: This is slightly suboptimal because we create two SqlParsingAggregator instances with different configs # but a shared schema resolver. That's fine for now though - once we remove the old lineage/usage extractors, # it should be pretty straightforward to refactor this and only initialize the aggregator once. + # This also applies to the _is_temp_table and _is_allowed_table methods above, which are duplicated from SnowflakeQueriesExtractor. 
self.report.queries_extractor = queries_extractor.report yield from queries_extractor.get_workunits_internal() queries_extractor.close() @@ -568,12 +603,14 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if ( self.config.include_usage_stats or self.config.include_operational_stats ) and self.usage_extractor: - yield from self.usage_extractor.get_usage_workunits(discovered_datasets) + yield from self.usage_extractor.get_usage_workunits( + self.discovered_datasets + ) if self.config.include_assertion_results: yield from SnowflakeAssertionsHandler( self.config, self.report, self.connection, self.identifiers - ).get_assertion_workunits(discovered_datasets) + ).get_assertion_workunits(self.discovered_datasets) self.connection.close()