Skip to content

Commit

Permalink
Merge pull request #298 from DataRecce/feature/drc-362-support-sqlmesh-lineage-diff-schema-diff
Browse files Browse the repository at this point in the history

[Chore] Polish the SQLMesh Integration
  • Loading branch information
popcornylu authored May 2, 2024
2 parents 63a1c36 + ed4c938 commit 3ca6b06
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 60 deletions.
87 changes: 56 additions & 31 deletions js/src/components/env/EnvInfo.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -136,37 +136,62 @@ export function EnvInfo() {
</>
)}
<Divider />
<Flex justifyContent="left" gap="5px" direction="column">
<Heading size="sm">DBT</Heading>
<TableContainer>
<Table variant="simple">
<Thead>
<Tr>
<Th></Th>
<Th>current</Th>
<Th>base</Th>
</Tr>
</Thead>
<Tbody>
<Tr>
<Td>schema</Td>
<Td>{JSON.stringify(Array.from(currentSchemas))}</Td>
<Td>{JSON.stringify(Array.from(baseSchemas))}</Td>
</Tr>
<Tr>
<Td>version</Td>
<Td>{dbtCurrent?.dbt_version}</Td>
<Td>{dbtBase?.dbt_version}</Td>
</Tr>
<Tr>
<Td>timestamp</Td>
<Td>{currentTime}</Td>
<Td>{baseTime}</Td>
</Tr>
</Tbody>
</Table>
</TableContainer>
</Flex>
{envInfo?.adapterType === "dbt" && (
<Flex justifyContent="left" gap="5px" direction="column">
<Heading size="sm">DBT</Heading>
<TableContainer>
<Table variant="simple">
<Thead>
<Tr>
<Th></Th>
<Th>base</Th>
<Th>current</Th>
</Tr>
</Thead>
<Tbody>
<Tr>
<Td>schema</Td>
<Td>{JSON.stringify(Array.from(baseSchemas))}</Td>
<Td>{JSON.stringify(Array.from(currentSchemas))}</Td>
</Tr>
<Tr>
<Td>version</Td>
<Td>{dbtBase?.dbt_version}</Td>
<Td>{dbtCurrent?.dbt_version}</Td>
</Tr>
<Tr>
<Td>timestamp</Td>
<Td>{baseTime}</Td>
<Td>{currentTime}</Td>
</Tr>
</Tbody>
</Table>
</TableContainer>
</Flex>
)}
{envInfo?.adapterType === "sqlmesh" && (
<Flex justifyContent="left" gap="5px" direction="column">
<Heading size="sm">SQLMesh</Heading>
<TableContainer>
<Table variant="simple">
<Thead>
<Tr>
<Th></Th>
<Th>base</Th>
<Th>current</Th>
</Tr>
</Thead>
<Tbody>
<Tr>
<Td>Environment</Td>
<Td>{envInfo?.sqlmesh?.base_env}</Td>
<Td>{envInfo?.sqlmesh?.current_env}</Td>
</Tr>
</Tbody>
</Table>
</TableContainer>
</Flex>
)}
</Flex>
</ModalBody>
<ModalFooter>
Expand Down
14 changes: 12 additions & 2 deletions js/src/components/query/QueryPage.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import React, { useState, useCallback } from "react";
import { Box, Button, Flex } from "@chakra-ui/react";
import SqlEditor from "./SqlEditor";
import { useRecceQueryContext } from "@/lib/hooks/RecceQueryContext";
import {
defaultSqlQuery,
useRecceQueryContext,
} from "@/lib/hooks/RecceQueryContext";

import { createCheckByRun, updateCheck } from "@/lib/api/checks";
import { QueryDiffResultView } from "./QueryDiffResultView";
Expand All @@ -18,9 +21,16 @@ import { QueryResultView } from "./QueryResultView";
import { cancelRun, waitRun } from "@/lib/api/runs";
import { RunView } from "../run/RunView";
import { Run } from "@/lib/api/types";
import { useLineageGraphContext } from "@/lib/hooks/LineageGraphContext";

export const QueryPage = () => {
const { sqlQuery, setSqlQuery } = useRecceQueryContext();
const { sqlQuery: _sqlQuery, setSqlQuery } = useRecceQueryContext();
const { envInfo } = useLineageGraphContext();

let sqlQuery = _sqlQuery;
if (envInfo?.adapterType === "sqlmesh" && _sqlQuery === defaultSqlQuery) {
sqlQuery = `select * from db.mymodel`;
}

const [runType, setRunType] = useState<string>();
const [runId, setRunId] = useState<string>();
Expand Down
6 changes: 5 additions & 1 deletion js/src/lib/api/info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ export interface ManifestMetadata extends ArtifactMetadata {
project_name?: string;
user_id?: string;
}
/**
 * SQLMesh-specific section of the server info payload.
 *
 * The server only adds this section when its adapter type is "sqlmesh",
 * which is why consumers declare the corresponding property as optional.
 */
export interface SQLMeshInfo {
/** Name of the SQLMesh environment used as the base side of the comparison. */
base_env: string;
/** Name of the SQLMesh environment used as the current side of the comparison. */
current_env: string;
}

export interface CatalogMetadata extends ArtifactMetadata {}

Expand All @@ -54,7 +58,6 @@ export interface LineageData {
manifest_metadata?: ManifestMetadata | null;
catalog_metadata?: CatalogMetadata | null;
}

interface LineageOutput {
error?: string;
data?: LineageData;
Expand Down Expand Up @@ -123,6 +126,7 @@ export interface ServerInfoResult {
review_mode: boolean;
git?: gitInfo;
pull_request?: pullRequestInfo;
sqlmesh?: SQLMeshInfo;
lineage: {
base: LineageData;
current: LineageData;
Expand Down
3 changes: 3 additions & 0 deletions js/src/lib/hooks/LineageGraphContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import React, {
import { cacheKeys } from "../api/cacheKeys";
import {
ManifestMetadata,
SQLMeshInfo,
getServerInfo,
gitInfo,
pullRequestInfo,
Expand All @@ -27,6 +28,7 @@ interface EnvInfo {
base: ManifestMetadata | undefined | null;
current: ManifestMetadata | undefined | null;
};
sqlmesh?: SQLMeshInfo | null;
}

export interface LineageGraphContextType {
Expand Down Expand Up @@ -137,6 +139,7 @@ export function LineageGraphContextProvider({ children }: LineageGraphProps) {
base: dbtBase,
current: dbtCurrent,
},
sqlmesh: data?.sqlmesh,
};

return (
Expand Down
35 changes: 19 additions & 16 deletions js/src/lib/hooks/RecceQueryContext.tsx
Original file line number Diff line number Diff line change
@@ -1,35 +1,34 @@
import React, { createContext, useContext } from 'react';
import React, { createContext, useContext } from "react";

export interface QueryContext {
sqlQuery: string;
setSqlQuery: (sqlQuery: string) => void;
sqlQuery: string;
setSqlQuery: (sqlQuery: string) => void;
}

const defaultSqlQuery = 'select * from {{ ref("mymodel") }}';
export const defaultSqlQuery = 'select * from {{ ref("mymodel") }}';

const defaultQueryContext: QueryContext = {
sqlQuery: defaultSqlQuery,
setSqlQuery: () => {},
sqlQuery: defaultSqlQuery,
setSqlQuery: () => {},
};

const RecceQueryContext = createContext(defaultQueryContext);

interface QueryContextProps {
children: React.ReactNode;
children: React.ReactNode;
}

export function RecceQueryContextProvider({ children }: QueryContextProps) {
const [sqlQuery, setSqlQuery] = React.useState<string>(defaultSqlQuery);
return (
<RecceQueryContext.Provider value={{ setSqlQuery, sqlQuery }}>
{children}
</RecceQueryContext.Provider>
);
const [sqlQuery, setSqlQuery] = React.useState<string>(defaultSqlQuery);
return (
<RecceQueryContext.Provider value={{ setSqlQuery, sqlQuery }}>
{children}
</RecceQueryContext.Provider>
);
}

/**
 * Hook that reads the current `RecceQueryContext` value via `useContext`,
 * giving access to `sqlQuery` and `setSqlQuery`.
 */
export const useRecceQueryContext = () => useContext(RecceQueryContext);


export interface RowCountStateContext {
isNodesFetching: string[];
setIsNodesFetching: (nodes: string[]) => void;
Expand All @@ -46,10 +45,14 @@ interface RowCountStateContextProps {
children: React.ReactNode;
}

export function RowCountStateContextProvider({ children }: RowCountStateContextProps) {
export function RowCountStateContextProvider({
children,
}: RowCountStateContextProps) {
const [isNodesFetching, setIsNodesFetching] = React.useState<string[]>([]);
return (
<RowCountStateContext.Provider value={{ isNodesFetching , setIsNodesFetching }}>
<RowCountStateContext.Provider
value={{ isNodesFetching, setIsNodesFetching }}
>
{children}
</RowCountStateContext.Provider>
);
Expand Down
20 changes: 14 additions & 6 deletions recce/adapter/sqlmesh_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional

import pandas as pd
from sqlglot import parse_one, Expression
from sqlglot import parse_one, Expression, select
from sqlglot.optimizer import traverse_scope
from sqlmesh.core.context import Context as SqlmeshContext
from sqlmesh.core.environment import Environment
Expand Down Expand Up @@ -90,18 +90,26 @@ def fetchdf_with_limit(
limit: Optional[int] = None
) -> (pd.DataFrame, bool):
if isinstance(sql, str):
expression = parse_one(sql)
expression = parse_one(sql, dialect=self.context.default_dialect)
else:
expression = sql

expression = expression.limit(limit + 1 if limit else None)

env = self.base_env if base else self.curr_env
model_names = [model.name for model in self.context.models.values()]
if env.name != 'prod':
for scope in traverse_scope(expression):
for table in scope.tables:
table.args['db'] = f"{table.args['db']}__{env.name}"

if f'{table.db}.{table.name}' in model_names:
table.args['db'] = f"{table.args['db']}__{env.name}"

if limit:
expression = select(
'*'
).from_(
'__QUERY'
).with_(
'__QUERY', as_=expression
).limit(limit + 1 if limit else None)
df = self.context.fetchdf(expression)
if limit and len(df) > limit:
df = df.head(limit)
Expand Down
12 changes: 11 additions & 1 deletion recce/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ async def get_info():
state = context.export_state()

try:
return {
info = {
'adapter_type': context.adapter_type,
'review_mode': context.review_mode,
'git': state.git.to_dict() if state.git else None,
Expand All @@ -156,6 +156,16 @@ async def get_info():
},
'demo': bool(demo)
}

if context.adapter_type == 'sqlmesh':
from recce.adapter.sqlmesh_adapter import SqlmeshAdapter
sqlmesh_adapter: SqlmeshAdapter = context.adapter
info['sqlmesh'] = {
'base_env': sqlmesh_adapter.base_env.name,
'current_env': sqlmesh_adapter.curr_env.name,
}

return info
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

Expand Down
4 changes: 3 additions & 1 deletion recce/tasks/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ class DataFrame(BaseModel):

@staticmethod
def from_agate(table: 'agate.Table', limit: t.Optional[int] = None, more: t.Optional[bool] = None):
import dbt.clients.agate_helper
import agate
columns = []

for col_name, col_type in zip(table.column_names, table.column_types):
import dbt.clients.agate_helper

has_integer = hasattr(dbt.clients.agate_helper, 'Integer')

if isinstance(col_type, agate.Number):
Expand Down
4 changes: 2 additions & 2 deletions recce/tasks/rowcount.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ def execute_sqlmesh(self):
curr_row_count = None

try:
df = sqlmesh_adapter.fetchdf(f'select count(*) from {name}')
df, _ = sqlmesh_adapter.fetchdf_with_limit(f'select count(*) from {name}', base=True)
base_row_count = int(df.iloc[0, 0])
except Exception:
pass
self.check_cancel()

try:
df = sqlmesh_adapter.fetchdf(f'select count(*) from {name}', env='dev')
df, _ = sqlmesh_adapter.fetchdf_with_limit(f'select count(*) from {name}', base=False)
curr_row_count = int(df.iloc[0, 0])
except Exception:
pass
Expand Down

0 comments on commit 3ca6b06

Please sign in to comment.