forked from sodadata/soda-core
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: profiler implementation (sodadata#1775)
The purpose of this PR is to refactor numeric and text column profiler implementation. The new approach has more modular and easy-to-maintain code. Lots of nested if/else and for loop logic have been refactored. Apart from unit tests, to make sure that there is nothing changed in the cloud, I also made e2e tests, and the screenshots are attached.
- Loading branch information
Showing
7 changed files
with
428 additions
and
333 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
from soda.execution.query.query import Query | ||
from soda.profiling.profile_columns_result import ProfileColumnsResultColumn | ||
|
||
if TYPE_CHECKING: | ||
from soda.execution.data_source_scan import DataSourceScan | ||
from soda.sodacl.data_source_check_cfg import ProfileColumnsCfg | ||
|
||
|
||
class NumericColumnProfiler: | ||
def __init__( | ||
self, | ||
data_source_scan: DataSourceScan, | ||
profile_columns_cfg: ProfileColumnsCfg, | ||
table_name: str, | ||
column_name: str, | ||
column_data_type: str, | ||
) -> None: | ||
self.data_source_scan = data_source_scan | ||
self.data_source = data_source_scan.data_source | ||
self.logs = data_source_scan.scan._logs | ||
self.profile_columns_cfg = profile_columns_cfg | ||
self.table_name = table_name | ||
self.column_name = column_name | ||
self.column_data_type = column_data_type | ||
self.result_column = ProfileColumnsResultColumn(column_name=column_name, column_data_type=column_data_type) | ||
|
||
def profile(self) -> ProfileColumnsResultColumn: | ||
self.logs.debug(f"Profiling column {self.column_name} of {self.table_name}") | ||
|
||
# mins, maxs, min, max, frequent values | ||
self._set_result_column_value_frequency_attributes() | ||
|
||
# Average, sum, variance, standard deviation, distinct values, missing values | ||
self._set_result_column_aggregation_attributes() | ||
|
||
# histogram | ||
self._set_result_column_histogram_attributes() | ||
return self.result_column | ||
|
||
def _set_result_column_value_frequency_attributes(self) -> None: | ||
value_frequencies = self._compute_value_frequency() | ||
if value_frequencies: | ||
self.result_column.set_min_max_metrics(value_frequencies=value_frequencies) | ||
self.result_column.set_frequency_metric(value_frequencies=value_frequencies) | ||
else: | ||
self.logs.error( | ||
"Database returned no results for minumum values, maximum values and " | ||
f"frequent values in table: {self.table_name}, columns: {self.column_name}" | ||
) | ||
|
||
def _set_result_column_aggregation_attributes(self) -> None: | ||
aggregated_metrics = self._compute_aggregated_metrics() | ||
if aggregated_metrics: | ||
self.result_column.set_numeric_aggregation_metrics(aggregated_metrics=aggregated_metrics) | ||
else: | ||
self.logs.error( | ||
f"Database returned no results for aggregates in table: {self.table_name}, columns: {self.column_name}" | ||
) | ||
|
||
def _set_result_column_histogram_attributes(self) -> None: | ||
histogram_values = self._compute_histogram() | ||
if histogram_values: | ||
self.result_column.set_histogram(histogram_values=histogram_values) | ||
else: | ||
self.logs.error( | ||
f"Database returned no results for histograms in table: {self.table_name}, columns: {self.column_name}" | ||
) | ||
|
||
def _compute_value_frequency(self) -> list[tuple] | None: | ||
value_frequencies_sql = self.data_source.profiling_sql_values_frequencies_query( | ||
"numeric", | ||
self.table_name, | ||
self.column_name, | ||
self.profile_columns_cfg.limit_mins_maxs, | ||
self.profile_columns_cfg.limit_frequent_values, | ||
) | ||
|
||
value_frequencies_query = Query( | ||
data_source_scan=self.data_source_scan, | ||
unqualified_query_name=f"profiling-{self.table_name}-{self.column_name}-value-frequencies-numeric", | ||
sql=value_frequencies_sql, | ||
) | ||
value_frequencies_query.execute() | ||
rows = value_frequencies_query.rows | ||
return rows | ||
|
||
def _compute_aggregated_metrics(self) -> list[tuple] | None: | ||
aggregates_sql = self.data_source.profiling_sql_aggregates_numeric(self.table_name, self.column_name) | ||
aggregates_query = Query( | ||
data_source_scan=self.data_source_scan, | ||
unqualified_query_name=f"profiling-{self.table_name}-{self.column_name}-profiling-aggregates", | ||
sql=aggregates_sql, | ||
) | ||
aggregates_query.execute() | ||
rows = aggregates_query.rows | ||
return rows | ||
|
||
def _compute_histogram(self) -> None | dict[str, list]: | ||
if self.result_column.min is None: | ||
self.logs.warning("Min cannot be None, make sure the min metric is derived before histograms") | ||
if self.result_column.max is None: | ||
self.logs.warning("Max cannot be None, make sure the min metric is derived before histograms") | ||
if self.result_column.distinct_values is None: | ||
self.logs.warning( | ||
"Distinct values cannot be None, make sure the distinct values metric is derived before histograms" | ||
) | ||
if ( | ||
self.result_column.min is None | ||
or self.result_column.max is None | ||
or self.result_column.distinct_values is None | ||
): | ||
self.logs.warning( | ||
f"Histogram query for {self.table_name}, column {self.column_name} skipped. See earlier warnings." | ||
) | ||
return None | ||
|
||
histogram_sql, bins_list = self.data_source.histogram_sql_and_boundaries( | ||
table_name=self.table_name, | ||
column_name=self.column_name, | ||
min_value=self.result_column.min, | ||
max_value=self.result_column.max, | ||
n_distinct=self.result_column.distinct_values, | ||
column_type=self.column_data_type, | ||
) | ||
if histogram_sql is None: | ||
return None | ||
|
||
histogram_query = Query( | ||
data_source_scan=self.data_source_scan, | ||
unqualified_query_name=f"profiling-{self.table_name}-{self.column_name}-histogram", | ||
sql=histogram_sql, | ||
) | ||
histogram_query.execute() | ||
histogram_values = histogram_query.rows | ||
|
||
if histogram_values is None: | ||
return None | ||
histogram = {} | ||
histogram["boundaries"] = bins_list | ||
histogram["frequencies"] = [int(freq) if freq is not None else 0 for freq in histogram_values[0]] | ||
return histogram |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
from soda.execution.query.query import Query | ||
from soda.profiling.profile_columns_result import ProfileColumnsResultColumn | ||
|
||
if TYPE_CHECKING: | ||
from soda.execution.data_source_scan import DataSourceScan | ||
from soda.sodacl.data_source_check_cfg import ProfileColumnsCfg | ||
|
||
|
||
class TextColumnProfiler: | ||
def __init__( | ||
self, | ||
data_source_scan: DataSourceScan, | ||
profile_columns_cfg: ProfileColumnsCfg, | ||
table_name: str, | ||
column_name: str, | ||
column_data_type: str, | ||
) -> None: | ||
self.data_source_scan = data_source_scan | ||
self.data_source = data_source_scan.data_source | ||
self.logs = data_source_scan.scan._logs | ||
self.profile_columns_cfg = profile_columns_cfg | ||
self.table_name = table_name | ||
self.column_name = column_name | ||
self.column_data_type = column_data_type | ||
self.result_column = ProfileColumnsResultColumn(column_name=column_name, column_data_type=column_data_type) | ||
|
||
def profile(self) -> ProfileColumnsResultColumn: | ||
self.logs.debug(f"Profiling column {self.column_name} of {self.table_name}") | ||
|
||
# frequent values for text column | ||
self._set_result_column_value_frequency_attribute() | ||
|
||
# pure text aggregates | ||
self._set_result_column_text_aggregation_attributes() | ||
|
||
return self.result_column | ||
|
||
def _set_result_column_value_frequency_attribute(self) -> None: | ||
value_frequencies = self._compute_value_frequency() | ||
if value_frequencies: | ||
self.result_column.set_frequency_metric(value_frequencies) | ||
else: | ||
self.logs.warning( | ||
f"Database returned no results for textual frequent values in {self.table_name}, column: {self.column_name}" | ||
) | ||
|
||
def _set_result_column_text_aggregation_attributes(self) -> None: | ||
text_aggregates = self._compute_text_aggregates() | ||
if text_aggregates: | ||
self.result_column.set_text_aggregation_metrics(text_aggregates) | ||
else: | ||
self.logs.warning( | ||
f"Database returned no results for textual aggregates in {self.table_name}, column: {self.column_name}" | ||
) | ||
|
||
def _compute_value_frequency(self) -> list[tuple] | None: | ||
# frequent values for text column | ||
value_frequencies_sql = self.data_source.profiling_sql_values_frequencies_query( | ||
"text", | ||
self.table_name, | ||
self.column_name, | ||
self.profile_columns_cfg.limit_mins_maxs, | ||
self.profile_columns_cfg.limit_frequent_values, | ||
) | ||
value_frequencies_query = Query( | ||
data_source_scan=self.data_source_scan, | ||
unqualified_query_name=f"profiling-{self.table_name}-{self.column_name}-value-frequencies-text", | ||
sql=value_frequencies_sql, | ||
) | ||
value_frequencies_query.execute() | ||
frequency_rows = value_frequencies_query.rows | ||
return frequency_rows | ||
|
||
def _compute_text_aggregates(self) -> list[tuple] | None: | ||
text_aggregates_sql = self.data_source.profiling_sql_aggregates_text(self.table_name, self.column_name) | ||
text_aggregates_query = Query( | ||
data_source_scan=self.data_source_scan, | ||
unqualified_query_name=f"profiling: {self.table_name}, {self.column_name}: get textual aggregates", | ||
sql=text_aggregates_sql, | ||
) | ||
text_aggregates_query.execute() | ||
text_aggregates_rows = text_aggregates_query.rows | ||
return text_aggregates_rows |
Oops, something went wrong.