Skip to content

Commit

Permalink
refactor: profiler implementation (sodadata#1775)
Browse files Browse the repository at this point in the history
The purpose of this PR is to refactor the numeric and text column profiler implementations. The new approach is more modular and easier to maintain; much of the nested if/else and for-loop logic has been simplified.

Apart from unit tests, to make sure that nothing changed on the cloud side, I also ran e2e tests; the screenshots are attached.
  • Loading branch information
baturayo authored Feb 5, 2023
1 parent fff5f0c commit 13f3e20
Show file tree
Hide file tree
Showing 7 changed files with 428 additions and 333 deletions.
338 changes: 62 additions & 276 deletions soda/core/soda/execution/check/profile_columns_run.py

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions soda/core/soda/execution/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def get_tables_columns_metadata(
include_patterns: list[dict[str, str]] | None = None,
exclude_patterns: list[dict[str, str]] | None = None,
table_names_only: bool = False,
) -> defaultdict[str, dict[str, str]] | list[str] | None:
) -> defaultdict[str, dict[str, str]] | None:
# TODO: save/cache the result for later use.
if (not include_patterns) and (not exclude_patterns):
return []
Expand All @@ -447,7 +447,7 @@ def get_tables_columns_metadata(
if table_names_only:
query_result = [self._optionally_quote_table_name_from_meta_data(row[0]) for row in rows]
else:
query_result: defaultdict(dict) = self.parse_tables_columns_query(rows)
query_result: defaultdict[dict] = self.parse_tables_columns_query(rows)
return query_result
return None

Expand Down Expand Up @@ -856,8 +856,8 @@ def profiling_sql_aggregates_numeric(self, table_name: str, column_name: str) ->
SELECT
avg({column_name}) as average
, sum({column_name}) as sum
, variance({column_name}) as variance
, stddev({column_name}) as standard_deviation
, var_samp({column_name}) as variance
, stddev_samp({column_name}) as standard_deviation
, count(distinct({column_name})) as distinct_values
, sum(case when {column_name} is null then 1 else 0 end) as missing_values
FROM {qualified_table_name}
Expand Down
145 changes: 145 additions & 0 deletions soda/core/soda/profiling/numeric_column_profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from soda.execution.query.query import Query
from soda.profiling.profile_columns_result import ProfileColumnsResultColumn

if TYPE_CHECKING:
from soda.execution.data_source_scan import DataSourceScan
from soda.sodacl.data_source_check_cfg import ProfileColumnsCfg


class NumericColumnProfiler:
    """Computes profiling metrics for a single numeric column.

    Runs three profiling queries against the data source (value frequencies,
    numeric aggregates, histogram) and collects the results into a
    ProfileColumnsResultColumn.
    """

    def __init__(
        self,
        data_source_scan: DataSourceScan,
        profile_columns_cfg: ProfileColumnsCfg,
        table_name: str,
        column_name: str,
        column_data_type: str,
    ) -> None:
        self.data_source_scan = data_source_scan
        self.data_source = data_source_scan.data_source
        self.logs = data_source_scan.scan._logs
        self.profile_columns_cfg = profile_columns_cfg
        self.table_name = table_name
        self.column_name = column_name
        self.column_data_type = column_data_type
        self.result_column = ProfileColumnsResultColumn(column_name=column_name, column_data_type=column_data_type)

    def profile(self) -> ProfileColumnsResultColumn:
        """Run all profiling steps and return the populated result column.

        Order matters: the histogram step reads min, max and distinct_values
        from the result column, so it must run after the first two steps.
        """
        self.logs.debug(f"Profiling column {self.column_name} of {self.table_name}")

        # mins, maxs, min, max, frequent values
        self._set_result_column_value_frequency_attributes()

        # Average, sum, variance, standard deviation, distinct values, missing values
        self._set_result_column_aggregation_attributes()

        # histogram (depends on min, max and distinct_values set above)
        self._set_result_column_histogram_attributes()
        return self.result_column

    def _set_result_column_value_frequency_attributes(self) -> None:
        # An empty/None result set is reported as an error because min/max and
        # frequent-values metrics cannot be derived without it.
        value_frequencies = self._compute_value_frequency()
        if value_frequencies:
            self.result_column.set_min_max_metrics(value_frequencies=value_frequencies)
            self.result_column.set_frequency_metric(value_frequencies=value_frequencies)
        else:
            self.logs.error(
                "Database returned no results for minimum values, maximum values and "
                f"frequent values in table: {self.table_name}, columns: {self.column_name}"
            )

    def _set_result_column_aggregation_attributes(self) -> None:
        aggregated_metrics = self._compute_aggregated_metrics()
        if aggregated_metrics:
            self.result_column.set_numeric_aggregation_metrics(aggregated_metrics=aggregated_metrics)
        else:
            self.logs.error(
                f"Database returned no results for aggregates in table: {self.table_name}, columns: {self.column_name}"
            )

    def _set_result_column_histogram_attributes(self) -> None:
        histogram_values = self._compute_histogram()
        if histogram_values:
            self.result_column.set_histogram(histogram_values=histogram_values)
        else:
            self.logs.error(
                f"Database returned no results for histograms in table: {self.table_name}, columns: {self.column_name}"
            )

    def _compute_value_frequency(self) -> list[tuple] | None:
        """Execute the value-frequencies query and return its raw rows."""
        value_frequencies_sql = self.data_source.profiling_sql_values_frequencies_query(
            "numeric",
            self.table_name,
            self.column_name,
            self.profile_columns_cfg.limit_mins_maxs,
            self.profile_columns_cfg.limit_frequent_values,
        )

        value_frequencies_query = Query(
            data_source_scan=self.data_source_scan,
            unqualified_query_name=f"profiling-{self.table_name}-{self.column_name}-value-frequencies-numeric",
            sql=value_frequencies_sql,
        )
        value_frequencies_query.execute()
        return value_frequencies_query.rows

    def _compute_aggregated_metrics(self) -> list[tuple] | None:
        """Execute the numeric-aggregates query and return its raw rows."""
        aggregates_sql = self.data_source.profiling_sql_aggregates_numeric(self.table_name, self.column_name)
        aggregates_query = Query(
            data_source_scan=self.data_source_scan,
            unqualified_query_name=f"profiling-{self.table_name}-{self.column_name}-profiling-aggregates",
            sql=aggregates_sql,
        )
        aggregates_query.execute()
        return aggregates_query.rows

    def _compute_histogram(self) -> dict[str, list] | None:
        """Build the column histogram, or return None when a prerequisite
        metric is missing or the data source cannot produce a histogram query.
        """
        # min, max and distinct_values must have been derived already; warn for
        # each missing one, then skip the histogram.
        prerequisites = (
            ("Min", "min", self.result_column.min),
            ("Max", "max", self.result_column.max),
            ("Distinct values", "distinct values", self.result_column.distinct_values),
        )
        missing_prerequisite = False
        for label, metric_name, value in prerequisites:
            if value is None:
                missing_prerequisite = True
                self.logs.warning(
                    f"{label} cannot be None, make sure the {metric_name} metric is derived before histograms"
                )
        if missing_prerequisite:
            self.logs.warning(
                f"Histogram query for {self.table_name}, column {self.column_name} skipped. See earlier warnings."
            )
            return None

        histogram_sql, bins_list = self.data_source.histogram_sql_and_boundaries(
            table_name=self.table_name,
            column_name=self.column_name,
            min_value=self.result_column.min,
            max_value=self.result_column.max,
            n_distinct=self.result_column.distinct_values,
            column_type=self.column_data_type,
        )
        if histogram_sql is None:
            return None

        histogram_query = Query(
            data_source_scan=self.data_source_scan,
            unqualified_query_name=f"profiling-{self.table_name}-{self.column_name}-histogram",
            sql=histogram_sql,
        )
        histogram_query.execute()
        histogram_values = histogram_query.rows

        if histogram_values is None:
            return None
        # The histogram query returns one row of frequencies; None entries mean
        # an empty bin and are normalized to 0.
        return {
            "boundaries": bins_list,
            "frequencies": [int(freq) if freq is not None else 0 for freq in histogram_values[0]],
        }
70 changes: 61 additions & 9 deletions soda/core/soda/profiling/profile_columns_result.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations

from numbers import Number
from typing import Any

from soda.sodacl.data_source_check_cfg import ProfileColumnsCfg


class ProfileColumnsResultColumn:
def __init__(self, column_name: str, column_type: str):
def __init__(self, column_name: str, column_data_type: str):
self.column_name: str = column_name
self.column_data_type: str = column_data_type
self.mins: list[float | int] | None = None
self.maxs: list[float | int] | None = None
self.min: float | int | None = None
Expand All @@ -22,6 +26,60 @@ def __init__(self, column_name: str, column_type: str):
self.min_length: float | None = None
self.max_length: float | None = None

def set_min_max_metrics(self, value_frequencies: list[tuple]) -> None:
    """Derive the mins/maxs lists and scalar min/max from value-frequency rows.

    Rows are expected shaped (metric_name, _, value, frequency); only rows
    tagged "mins"/"maxs" are used. Assumes the SQL ordered the rows so the
    first entry of each list is the extreme value -- TODO confirm against the
    profiling query.
    """
    self.mins = [self.unify_type(row[2]) for row in value_frequencies if row[0] == "mins"]
    self.maxs = [self.unify_type(row[2]) for row in value_frequencies if row[0] == "maxs"]
    # Guard against result sets with no "mins"/"maxs" rows: indexing an empty
    # list would raise IndexError.
    self.min = self.mins[0] if self.mins else None
    self.max = self.maxs[0] if self.maxs else None

def set_frequency_metric(self, value_frequencies: list[tuple]) -> None:
    """Collect the "frequent_values" rows as {"value", "frequency"} dicts."""
    frequent: list[dict] = []
    for row in value_frequencies:
        if row[0] == "frequent_values":
            frequent.append({"value": str(row[2]), "frequency": int(row[3])})
    self.frequent_values = frequent

def set_numeric_aggregation_metrics(self, aggregated_metrics: list[tuple]) -> None:
    """Populate the numeric aggregate attributes from the first result row."""
    first_row = aggregated_metrics[0]
    self.average = self.cast_float_dtype_handle_none(first_row[0])
    self.sum = self.cast_float_dtype_handle_none(first_row[1])
    self.variance = self.cast_float_dtype_handle_none(first_row[2])
    self.standard_deviation = self.cast_float_dtype_handle_none(first_row[3])
    self.distinct_values = self.cast_int_dtype_handle_none(first_row[4])
    self.missing_values = self.cast_int_dtype_handle_none(first_row[5])

def set_histogram(self, histogram_values: dict[str, list]) -> None:
    """Store the histogram dict ({"boundaries": [...], "frequencies": [...]}) on the result column."""
    self.histogram = histogram_values

def set_text_aggregation_metrics(self, aggregated_metrics: list[tuple]) -> None:
    """Populate the text-column aggregate attributes from the first result row."""
    first_row = aggregated_metrics[0]
    self.distinct_values = self.cast_int_dtype_handle_none(first_row[0])
    self.missing_values = self.cast_int_dtype_handle_none(first_row[1])
    # TODO: after the discussion, we should change the type of the average_length to float
    # CLOUD-2764
    self.average_length = self.cast_int_dtype_handle_none(first_row[2])
    self.min_length = self.cast_int_dtype_handle_none(first_row[3])
    self.max_length = self.cast_int_dtype_handle_none(first_row[4])

@staticmethod
def unify_type(v: Any) -> Any:
if isinstance(v, Number):
return float(v)
else:
return v

@staticmethod
def cast_float_dtype_handle_none(value: float | None) -> float | None:
if value is None:
return None
# TODO: after the discussion, we should round float values upto n decimal places
# CLOUD-2765
cast_value = float(value)
return cast_value

@staticmethod
def cast_int_dtype_handle_none(value: int | None) -> int | None:
if value is None:
return None
cast_value = int(value)
return cast_value

def get_cloud_dict(self) -> dict:
cloud_dict = {
"columnName": self.column_name,
Expand Down Expand Up @@ -75,10 +133,8 @@ def __init__(self, table_name: str, data_source: str, row_count: int | None = No
self.row_count: int | None = row_count
self.result_columns: list[ProfileColumnsResultColumn] = []

def create_column(self, column_name: str, column_type: str) -> ProfileColumnsResultColumn:
column = ProfileColumnsResultColumn(column_name, column_type)
def append_column(self, column: ProfileColumnsResultColumn) -> None:
self.result_columns.append(column)
return column

def get_cloud_dict(self) -> dict:
cloud_dict = {
Expand All @@ -103,9 +159,5 @@ def __init__(self, profile_columns_cfg: ProfileColumnsCfg):
self.profile_columns_cfg: ProfileColumnsCfg = profile_columns_cfg
self.tables: list[ProfileColumnsResultTable] = []

def create_table(
self, table_name: str, data_source_name: str, row_count: int | None = None
) -> ProfileColumnsResultTable:
table = ProfileColumnsResultTable(table_name, data_source_name, row_count)
def append_table(self, table: ProfileColumnsResultTable) -> None:
self.tables.append(table)
return table
87 changes: 87 additions & 0 deletions soda/core/soda/profiling/text_column_profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from soda.execution.query.query import Query
from soda.profiling.profile_columns_result import ProfileColumnsResultColumn

if TYPE_CHECKING:
from soda.execution.data_source_scan import DataSourceScan
from soda.sodacl.data_source_check_cfg import ProfileColumnsCfg


class TextColumnProfiler:
    """Profiles a single text column: frequent values plus textual aggregates
    (distinct/missing counts and length statistics), collected into a
    ProfileColumnsResultColumn."""

    def __init__(
        self,
        data_source_scan: DataSourceScan,
        profile_columns_cfg: ProfileColumnsCfg,
        table_name: str,
        column_name: str,
        column_data_type: str,
    ) -> None:
        self.data_source_scan = data_source_scan
        self.data_source = data_source_scan.data_source
        self.logs = data_source_scan.scan._logs
        self.profile_columns_cfg = profile_columns_cfg
        self.table_name = table_name
        self.column_name = column_name
        self.column_data_type = column_data_type
        self.result_column = ProfileColumnsResultColumn(column_name=column_name, column_data_type=column_data_type)

    def profile(self) -> ProfileColumnsResultColumn:
        """Run both profiling steps and return the populated result column."""
        self.logs.debug(f"Profiling column {self.column_name} of {self.table_name}")
        # frequent values for text column
        self._set_result_column_value_frequency_attribute()
        # pure text aggregates
        self._set_result_column_text_aggregation_attributes()
        return self.result_column

    def _set_result_column_value_frequency_attribute(self) -> None:
        # An empty/None result set is a warning only -- the remaining metrics
        # can still be derived.
        frequency_rows = self._compute_value_frequency()
        if not frequency_rows:
            self.logs.warning(
                f"Database returned no results for textual frequent values in {self.table_name}, column: {self.column_name}"
            )
            return
        self.result_column.set_frequency_metric(frequency_rows)

    def _set_result_column_text_aggregation_attributes(self) -> None:
        aggregate_rows = self._compute_text_aggregates()
        if not aggregate_rows:
            self.logs.warning(
                f"Database returned no results for textual aggregates in {self.table_name}, column: {self.column_name}"
            )
            return
        self.result_column.set_text_aggregation_metrics(aggregate_rows)

    def _compute_value_frequency(self) -> list[tuple] | None:
        """Execute the frequent-values query for the text column and return its raw rows."""
        sql = self.data_source.profiling_sql_values_frequencies_query(
            "text",
            self.table_name,
            self.column_name,
            self.profile_columns_cfg.limit_mins_maxs,
            self.profile_columns_cfg.limit_frequent_values,
        )
        query = Query(
            data_source_scan=self.data_source_scan,
            unqualified_query_name=f"profiling-{self.table_name}-{self.column_name}-value-frequencies-text",
            sql=sql,
        )
        query.execute()
        return query.rows

    def _compute_text_aggregates(self) -> list[tuple] | None:
        """Execute the textual-aggregates query and return its raw rows."""
        sql = self.data_source.profiling_sql_aggregates_text(self.table_name, self.column_name)
        query = Query(
            data_source_scan=self.data_source_scan,
            unqualified_query_name=f"profiling: {self.table_name}, {self.column_name}: get textual aggregates",
            sql=sql,
        )
        query.execute()
        return query.rows
Loading

0 comments on commit 13f3e20

Please sign in to comment.