Skip to content

Commit

Permalink
Add /api/compare_column_values
Browse files Browse the repository at this point in the history
Signed-off-by: Ching Yi, Chan <[email protected]>
  • Loading branch information
qrtt1 committed Dec 13, 2023
1 parent 2c89b1f commit c0c700a
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ dbt.log
.coverage
build
.tox
.noai
17 changes: 17 additions & 0 deletions recce/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import traceback
from typing import List

import click
Expand Down Expand Up @@ -118,5 +119,21 @@ def server(host, port, **kwargs):
uvicorn.run(app, host=host, port=port, lifespan='on')


@cli.command(cls=TrackCommand)
@click.argument("primary_key", type=str, required=True)
@click.argument("model", type=str, required=True)
def compare_all_columns(primary_key: str, model: str):
"""
Compare all columns for a specified model.
"""
try:
dbt_context = DBTContext.load()
df = dbt_context.compare_all_columns(primary_key, model)
print(df.to_string(max_rows=None, max_colwidth=None))
except:
traceback.print_exc()
pass


if __name__ == "__main__":
cli()
56 changes: 53 additions & 3 deletions recce/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,25 @@
import os
import time
from dataclasses import dataclass, fields
from typing import Dict, List, Optional, Union, Callable
from typing import Callable, Dict, List, Optional, Union

import agate
import pandas as pd
from dbt.adapters.factory import get_adapter_by_type
from dbt.adapters.sql import SQLAdapter
from dbt.cli.main import dbtRunner
from dbt.config.profile import Profile
from dbt.config.project import Project
from dbt.config.project import Project, package_config_from_data
from dbt.config.renderer import PackageRenderer
from dbt.config.runtime import load_profile, load_project
from dbt.contracts.files import FileHash
from dbt.contracts.graph.manifest import Manifest, WritableManifest
from dbt.contracts.graph.model_config import ContractConfig, NodeConfig
from dbt.contracts.graph.nodes import Contract, DependsOn, ManifestNode, ModelNode, ResultNode, SourceDefinition
from dbt.contracts.graph.unparsed import Docs
from dbt.contracts.results import CatalogArtifact
from dbt.deps.base import downloads_directory
from dbt.deps.resolver import resolve_packages
from dbt.node_types import AccessType, ModelLanguage, NodeType
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
Expand Down Expand Up @@ -194,8 +197,25 @@ class DBTContext:
artifacts_files = []

@classmethod
def load(cls, **kwargs):
def packages_downloader(cls, project: Project):
# reference from dbt-core tasks/deps.py

os.environ["DBT_MACRO_DEBUGGING"] = "false"
os.environ["DBT_VERSION_CHECK"] = "false"

renderer = PackageRenderer({})
packages_lock_dict = {'packages': [{'package': 'dbt-labs/audit_helper', 'version': '0.9.0'}]}
packages_lock_config = package_config_from_data(
renderer.render_data(packages_lock_dict), packages_lock_dict
).packages

lock_defined_deps = resolve_packages(packages_lock_config, project, {})
with downloads_directory():
for package in lock_defined_deps:
package.install(project, renderer)

@classmethod
def load(cls, **kwargs):
# We need to run the dbt parse command because
# 1. load the dbt profiles by dbt-core rule
# 2. initialize the adapter
Expand Down Expand Up @@ -226,6 +246,11 @@ def load(cls, **kwargs):
profile = load_profile(project_path, {}, profile_name_override=profile_name, target_override=target)
project = load_project(project_path, False, profile)

packages = [x.package for x in project.packages.packages]
if not kwargs.get('skip_download', False) and 'dbt-labs/audit_helper' not in packages:
cls.packages_downloader(project)
return cls.load(**dict(skip_download=True, **kwargs))

adapter: SQLAdapter = get_adapter_by_type(profile.credentials.type)

dbt_context = cls(profile=profile,
Expand Down Expand Up @@ -414,3 +439,28 @@ def refresh(self, refresh_file_path: str = None):
self.base_manifest = load_manifest(refresh_file_path)
elif refresh_file_path.endswith('catalog.json'):
self.base_catalog = load_catalog(refresh_file_path)

def compare_all_columns(self, primary_key: str, model: str):
# find the relation names
base_relation = self.generate_sql(f'{{{{ ref("{model}") }}}}', True)

# https://github.com/dbt-labs/dbt-audit-helper/tree/0.9.0/#compare_column_values-source
# {{
# audit_helper.compare_all_columns(
# a_relation=ref('stg_customers'),
# b_relation=api.Relation.create(database='dbt_db', schema='analytics_prod', identifier='stg_customers'),
# exclude_columns=['updated_at'],
# primary_key='id'
# )
# }}

sql_template = f"""
{{{{
audit_helper.compare_all_columns(
a_relation=ref('{model}'),
b_relation='{base_relation}',
primary_key='{primary_key}'
)
}}}}
"""
return self.execute_sql(sql_template, False)
23 changes: 23 additions & 0 deletions recce/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,29 @@ class QueryInput(BaseModel):
sql_template: str


class CompareColumnValuesInput(BaseModel):
model: str
primary_key: str
exclude_columns: Optional[list]


@app.post("/api/compare_column_values")
async def compare_column_values(input: CompareColumnValuesInput):
try:
# TODO support exclude columns
print(input)
result = dbt_context.compare_all_columns(input.primary_key, input.model)
result_json = result.to_json(orient='table')

import json
return json.loads(result_json)

except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=400, detail=str(e))


@app.post("/api/query")
async def query(input: QueryInput):
from jinja2.exceptions import TemplateSyntaxError
Expand Down

0 comments on commit c0c700a

Please sign in to comment.