diff --git a/mypy-baseline.txt b/mypy-baseline.txt index d64109444493e..0fb6ecf8326f6 100644 --- a/mypy-baseline.txt +++ b/mypy-baseline.txt @@ -50,9 +50,11 @@ posthog/temporal/data_imports/pipelines/rest_source/__init__.py:0: error: Argume posthog/temporal/data_imports/pipelines/rest_source/__init__.py:0: error: Incompatible types in assignment (expression has type "list[str] | None", variable has type "list[str]") [assignment] posthog/temporal/data_imports/pipelines/rest_source/__init__.py:0: error: Argument 1 to "setup_incremental_object" has incompatible type "dict[str, ResolveParamConfig | IncrementalParamConfig | Any] | None"; expected "dict[str, Any]" [arg-type] posthog/temporal/data_imports/pipelines/rest_source/__init__.py:0: error: Argument "base_url" to "RESTClient" has incompatible type "str | None"; expected "str" [arg-type] +posthog/temporal/data_imports/pipelines/rest_source/__init__.py:0: error: Argument 2 to "convert_types" has incompatible type "dict[str, TColumnSchema] | Sequence[TColumnSchema] | BaseModel | type[BaseModel] | Callable[[Any], dict[str, TColumnSchema] | Sequence[TColumnSchema] | BaseModel | type[BaseModel]] | None"; expected "dict[str, dict[str, Any]] | None" [arg-type] posthog/temporal/data_imports/pipelines/rest_source/__init__.py:0: error: Argument 1 to "exclude_keys" has incompatible type "dict[str, ResolveParamConfig | IncrementalParamConfig | Any] | None"; expected "Mapping[str, Any]" [arg-type] posthog/temporal/data_imports/pipelines/rest_source/__init__.py:0: error: Incompatible default for argument "resolved_param" (default has type "ResolvedParam | None", argument has type "ResolvedParam") [assignment] posthog/temporal/data_imports/pipelines/rest_source/__init__.py:0: error: Unused "type: ignore" comment [unused-ignore] +posthog/temporal/data_imports/pipelines/rest_source/__init__.py:0: error: Argument 2 to "convert_types" has incompatible type "dict[str, TColumnSchema] | Sequence[TColumnSchema] | BaseModel | type[BaseModel] | Callable[[Any], dict[str, TColumnSchema] | Sequence[TColumnSchema] | BaseModel | type[BaseModel]] | None"; expected "dict[str, dict[str, Any]] | None" [arg-type] posthog/temporal/data_imports/pipelines/vitally/__init__.py:0: error: Unused "type: ignore" comment [unused-ignore] posthog/temporal/data_imports/pipelines/vitally/__init__.py:0: error: Unused "type: ignore" comment [unused-ignore] posthog/temporal/data_imports/pipelines/vitally/__init__.py:0: error: Unused "type: ignore" comment [unused-ignore] @@ -597,6 +599,10 @@ posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: note: d posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: note: def get(self, Type, Sequence[str], /) -> Sequence[str] posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: note: def [_T] get(self, Type, _T, /) -> Sequence[str] | _T posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py:0: error: Argument "source_id" to "sync_old_schemas_with_new_schemas" has incompatible type "str"; expected "UUID" [arg-type] +posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Incompatible types in assignment (expression has type "object", variable has type "DataWarehouseCredential | Combinable | None") [assignment] +posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Incompatible types in assignment (expression has type "object", variable has type "str | int | Combinable") [assignment] +posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Incompatible types in assignment (expression has type "dict[str, dict[str, str | bool]] | dict[str, str]", variable has type "dict[str, dict[str, str]]") [assignment] +posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Item "None" of "dict[str, str] | None" has no attribute "get" [union-attr] posthog/taxonomy/property_definition_api.py:0: error: Item "AnonymousUser" of "User | AnonymousUser" has no attribute "organization" [union-attr] posthog/taxonomy/property_definition_api.py:0: error: Item "None" of "Organization | Any | None" has no attribute "is_feature_available" [union-attr] posthog/taxonomy/property_definition_api.py:0: error: Item "ForeignObjectRel" of "Field[Any, Any] | ForeignObjectRel | GenericForeignKey" has no attribute "cached_col" [union-attr] @@ -754,13 +760,6 @@ posthog/temporal/tests/batch_exports/test_batch_exports.py:0: error: TypedDict k posthog/temporal/data_modeling/run_workflow.py:0: error: Dict entry 20 has incompatible type "str": "Literal['complex']"; expected "str": "Literal['text', 'double', 'bool', 'timestamp', 'bigint', 'binary', 'json', 'decimal', 'wei', 'date', 'time']" [dict-item] posthog/temporal/data_modeling/run_workflow.py:0: error: Dict entry 21 has incompatible type "str": "Literal['complex']"; expected "str": "Literal['text', 'double', 'bool', 'timestamp', 'bigint', 'binary', 'json', 'decimal', 'wei', 'date', 'time']" [dict-item] posthog/temporal/data_modeling/run_workflow.py:0: error: Dict entry 22 has incompatible type "str": "Literal['complex']"; expected "str": "Literal['text', 'double', 'bool', 'timestamp', 'bigint', 'binary', 'json', 'decimal', 'wei', 'date', 'time']" [dict-item] -posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: "FilesystemDestinationClientConfiguration" has no attribute "delta_jobs_per_write" [attr-defined] -posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: "type[FilesystemDestinationClientConfiguration]" has no attribute "delta_jobs_per_write" [attr-defined] -posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Incompatible types in assignment (expression has type "object", variable has type "DataWarehouseCredential | Combinable | None") [assignment] -posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Incompatible types in assignment (expression has type "object", variable has type "str | int | Combinable") [assignment] -posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Right operand of "and" is never evaluated [unreachable] -posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Statement is unreachable [unreachable] -posthog/temporal/data_imports/pipelines/pipeline_sync.py:0: error: Name "raw_db_columns" already defined on line 0 [no-redef] posthog/queries/app_metrics/test/test_app_metrics.py:0: error: Argument 3 to "AppMetricsErrorDetailsQuery" has incompatible type "AppMetricsRequestSerializer"; expected "AppMetricsErrorsRequestSerializer" [arg-type] posthog/queries/app_metrics/test/test_app_metrics.py:0: error: Argument 3 to "AppMetricsErrorDetailsQuery" has incompatible type "AppMetricsRequestSerializer"; expected "AppMetricsErrorsRequestSerializer" [arg-type] posthog/queries/app_metrics/test/test_app_metrics.py:0: error: Argument 3 to "AppMetricsErrorDetailsQuery" has incompatible type "AppMetricsRequestSerializer"; expected "AppMetricsErrorsRequestSerializer" [arg-type] @@ -790,8 +789,13 @@ posthog/api/plugin_log_entry.py:0: error: Name "timezone.datetime" is not define posthog/api/plugin_log_entry.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "datetime" [attr-defined] posthog/api/plugin_log_entry.py:0: error: Name "timezone.datetime" is not defined [name-defined] posthog/api/plugin_log_entry.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "datetime" [attr-defined] -posthog/api/sharing.py:0: error: Item "None" of "list[Any] | None" has no attribute "__iter__" (not iterable) [union-attr] posthog/temporal/data_imports/external_data_job.py:0: error: Argument "status" to "update_external_job_status" has incompatible type "str"; expected "Status" [arg-type] +posthog/api/sharing.py:0: error: Item "None" of "list[Any] | None" has no attribute "__iter__" (not iterable) [union-attr] +posthog/temporal/tests/external_data/test_external_data_job.py:0: error: Invalid index type "str" for "dict[Type, Sequence[str]]"; expected type "Type" [index] +posthog/temporal/tests/external_data/test_external_data_job.py:0: error: Invalid index type "str" for "dict[Type, Sequence[str]]"; expected type "Type" [index] +posthog/temporal/tests/external_data/test_external_data_job.py:0: error: Invalid index type "str" for "dict[Type, Sequence[str]]"; expected type "Type" [index] +posthog/temporal/tests/external_data/test_external_data_job.py:0: error: Invalid index type "str" for "dict[Type, Sequence[str]]"; expected type "Type" [index] +posthog/temporal/tests/data_imports/test_end_to_end.py:0: error: Unused "type: ignore" comment [unused-ignore] posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: Need type annotation for "_execute_calls" (hint: "_execute_calls: list[] = ...") [var-annotated] posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: Need type annotation for "_execute_async_calls" (hint: "_execute_async_calls: list[] = ...") [var-annotated] posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: Need type annotation for "_cursors" (hint: "_cursors: list[] = ...") [var-annotated] @@ -808,11 +812,6 @@ posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "s posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] -posthog/temporal/tests/external_data/test_external_data_job.py:0: error: Invalid index type "str" for "dict[Type, Sequence[str]]"; expected type "Type" [index] -posthog/temporal/tests/external_data/test_external_data_job.py:0: error: Invalid index type "str" for "dict[Type, Sequence[str]]"; expected type "Type" [index] -posthog/temporal/tests/external_data/test_external_data_job.py:0: error: Invalid index type "str" for "dict[Type, Sequence[str]]"; expected type "Type" [index] -posthog/temporal/tests/external_data/test_external_data_job.py:0: error: Invalid index type "str" for "dict[Type, Sequence[str]]"; expected type "Type" [index] -posthog/temporal/tests/data_imports/test_end_to_end.py:0: error: Unused "type: ignore" comment [unused-ignore] posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py:0: error: Incompatible types in assignment (expression has type "str | int", variable has type "int") [assignment] posthog/api/test/batch_exports/conftest.py:0: error: Signature of "run" incompatible with supertype "Worker" [override] posthog/api/test/batch_exports/conftest.py:0: note: Superclass: diff --git a/posthog/constants.py b/posthog/constants.py index 73a46e10f2476..bf525632bd148 100644 --- a/posthog/constants.py +++ b/posthog/constants.py @@ -303,7 +303,6 @@ class FlagRequestType(StrEnum): ENRICHED_DASHBOARD_INSIGHT_IDENTIFIER = "Feature Viewed" DATA_WAREHOUSE_TASK_QUEUE = "data-warehouse-task-queue" -DATA_WAREHOUSE_TASK_QUEUE_V2 = "v2-data-warehouse-task-queue" BATCH_EXPORTS_TASK_QUEUE = "batch-exports-task-queue" SYNC_BATCH_EXPORTS_TASK_QUEUE = "no-sandbox-python-django" GENERAL_PURPOSE_TASK_QUEUE = "general-purpose-task-queue" diff --git a/posthog/hogql/database/s3_table.py b/posthog/hogql/database/s3_table.py index 479969ae93bd1..1dfee7dbc65c7 100644 --- a/posthog/hogql/database/s3_table.py +++ b/posthog/hogql/database/s3_table.py @@ -1,5 +1,5 @@ import re -from typing import TYPE_CHECKING, Optional +from typing import Optional from posthog.clickhouse.client.escape import substitute_params from posthog.hogql.context import HogQLContext @@ -7,9 +7,6 @@ from posthog.hogql.errors import ExposedHogQLError from posthog.hogql.escape_sql import escape_hogql_identifier -if TYPE_CHECKING: - from posthog.warehouse.models import ExternalDataJob - def build_function_call( url: str, @@ -18,10 +15,7 @@ def build_function_call( access_secret: Optional[str] = None, structure: Optional[str] = None, context: Optional[HogQLContext] = None, - pipeline_version: Optional["ExternalDataJob.PipelineVersion"] = None, ) -> str: - from posthog.warehouse.models import ExternalDataJob - raw_params: dict[str, str] = {} def add_param(value: str, is_sensitive: bool = True) -> str: @@ -42,18 +36,10 @@ def return_expr(expr: str) -> str: # DeltaS3Wrapper format if format == "DeltaS3Wrapper": - query_folder = "__query_v2" if pipeline_version == ExternalDataJob.PipelineVersion.V2 else "__query" - if url.endswith("/"): - if pipeline_version == ExternalDataJob.PipelineVersion.V2: - escaped_url = add_param(f"{url[:-5]}{query_folder}/*.parquet") - else: - escaped_url = add_param(f"{url[:-1]}{query_folder}/*.parquet") + escaped_url = add_param(f"{url[:-1]}__query/*.parquet") else: - if pipeline_version == ExternalDataJob.PipelineVersion.V2: - escaped_url = add_param(f"{url[:-4]}{query_folder}/*.parquet") - else: - escaped_url = add_param(f"{url}{query_folder}/*.parquet") + escaped_url = add_param(f"{url}__query/*.parquet") if structure: escaped_structure = add_param(structure, False) diff --git a/posthog/management/commands/start_temporal_worker.py b/posthog/management/commands/start_temporal_worker.py index 7984df0672471..cf0ccdf2cc88e 100644 --- a/posthog/management/commands/start_temporal_worker.py +++ b/posthog/management/commands/start_temporal_worker.py @@ -11,7 +11,6 @@ from posthog.constants import ( BATCH_EXPORTS_TASK_QUEUE, DATA_WAREHOUSE_TASK_QUEUE, - DATA_WAREHOUSE_TASK_QUEUE_V2, GENERAL_PURPOSE_TASK_QUEUE, SYNC_BATCH_EXPORTS_TASK_QUEUE, ) @@ -32,14 +31,12 @@ SYNC_BATCH_EXPORTS_TASK_QUEUE: BATCH_EXPORTS_WORKFLOWS, BATCH_EXPORTS_TASK_QUEUE: BATCH_EXPORTS_WORKFLOWS, DATA_WAREHOUSE_TASK_QUEUE: DATA_SYNC_WORKFLOWS + DATA_MODELING_WORKFLOWS, - DATA_WAREHOUSE_TASK_QUEUE_V2: DATA_SYNC_WORKFLOWS + DATA_MODELING_WORKFLOWS, GENERAL_PURPOSE_TASK_QUEUE: PROXY_SERVICE_WORKFLOWS + DELETE_PERSONS_WORKFLOWS, } ACTIVITIES_DICT = { SYNC_BATCH_EXPORTS_TASK_QUEUE: BATCH_EXPORTS_ACTIVITIES, BATCH_EXPORTS_TASK_QUEUE: BATCH_EXPORTS_ACTIVITIES, DATA_WAREHOUSE_TASK_QUEUE: DATA_SYNC_ACTIVITIES + DATA_MODELING_ACTIVITIES, - DATA_WAREHOUSE_TASK_QUEUE_V2: DATA_SYNC_ACTIVITIES + DATA_MODELING_ACTIVITIES, GENERAL_PURPOSE_TASK_QUEUE: PROXY_SERVICE_ACTIVITIES + DELETE_PERSONS_ACTIVITIES, } diff --git a/posthog/temporal/data_imports/__init__.py b/posthog/temporal/data_imports/__init__.py index aab0a74ac554c..c59f20b05d8cf 100644 --- a/posthog/temporal/data_imports/__init__.py +++ b/posthog/temporal/data_imports/__init__.py @@ -6,7 +6,6 @@ update_external_data_job_model, check_billing_limits_activity, sync_new_schemas_activity, - trigger_pipeline_v2, ) WORKFLOWS = [ExternalDataJobWorkflow] @@ -18,5 +17,4 @@ create_source_templates, check_billing_limits_activity, sync_new_schemas_activity, - trigger_pipeline_v2, ] diff --git a/posthog/temporal/data_imports/external_data_job.py b/posthog/temporal/data_imports/external_data_job.py index 45f482811f754..beaf47836491b 100644 --- a/posthog/temporal/data_imports/external_data_job.py +++ b/posthog/temporal/data_imports/external_data_job.py @@ -1,26 +1,16 @@ -import asyncio import dataclasses import datetime as dt import json import re -import threading -import time -from django.conf import settings from django.db import close_old_connections import posthoganalytics -import psutil from temporalio import activity, exceptions, workflow from temporalio.common import RetryPolicy -from temporalio.exceptions import WorkflowAlreadyStartedError -from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE_V2 - # TODO: remove dependency -from posthog.settings.base_variables import TEST from posthog.temporal.batch_exports.base import PostHogWorkflow -from posthog.temporal.common.client import sync_connect from posthog.temporal.data_imports.workflow_activities.check_billing_limits import ( CheckBillingLimitsActivityInputs, check_billing_limits_activity, @@ -144,32 +134,6 @@ def update_external_data_job_model(inputs: UpdateExternalDataJobStatusInputs) -> ) -@activity.defn -def trigger_pipeline_v2(inputs: ExternalDataWorkflowInputs): - logger = bind_temporal_worker_logger_sync(team_id=inputs.team_id) - logger.debug("Triggering V2 pipeline") - - temporal = sync_connect() - try: - asyncio.run( - temporal.start_workflow( - workflow="external-data-job", - arg=dataclasses.asdict(inputs), - id=f"{inputs.external_data_schema_id}-V2", - task_queue=str(DATA_WAREHOUSE_TASK_QUEUE_V2), - retry_policy=RetryPolicy( - maximum_interval=dt.timedelta(seconds=60), - maximum_attempts=1, - non_retryable_error_types=["NondeterminismError"], - ), - ) - ) - except WorkflowAlreadyStartedError: - pass - - logger.debug("V2 pipeline triggered") - - @dataclasses.dataclass class CreateSourceTemplateInputs: team_id: int @@ -181,22 +145,6 @@ def create_source_templates(inputs: CreateSourceTemplateInputs) -> None: create_warehouse_templates_for_source(team_id=inputs.team_id, run_id=inputs.run_id) -def log_memory_usage(): - process = psutil.Process() - logger = bind_temporal_worker_logger_sync(team_id=0) - - while True: - memory_info = process.memory_info() - logger.info(f"Memory Usage: RSS = {memory_info.rss / (1024 * 1024):.2f} MB") - - time.sleep(10) # Log every 10 seconds - - -if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE_V2: - thread = threading.Thread(target=log_memory_usage, daemon=True) - thread.start() - - # TODO: update retry policies @workflow.defn(name="external-data-job") class ExternalDataJobWorkflow(PostHogWorkflow): @@ -209,14 +157,6 @@ def parse_inputs(inputs: list[str]) -> ExternalDataWorkflowInputs: async def run(self, inputs: ExternalDataWorkflowInputs): assert inputs.external_data_schema_id is not None - if settings.TEMPORAL_TASK_QUEUE != DATA_WAREHOUSE_TASK_QUEUE_V2 and not TEST: - await workflow.execute_activity( - trigger_pipeline_v2, - inputs, - start_to_close_timeout=dt.timedelta(minutes=1), - retry_policy=RetryPolicy(maximum_attempts=1), - ) - update_inputs = UpdateExternalDataJobStatusInputs( job_id=None, status=ExternalDataJob.Status.COMPLETED, diff --git a/posthog/temporal/data_imports/pipelines/pipeline/delta_table_helper.py b/posthog/temporal/data_imports/pipelines/pipeline/delta_table_helper.py index 0e98f948f9519..88ef3e0e69645 100644 --- a/posthog/temporal/data_imports/pipelines/pipeline/delta_table_helper.py +++ b/posthog/temporal/data_imports/pipelines/pipeline/delta_table_helper.py @@ -46,8 +46,7 @@ def _get_credentials(self): def _get_delta_table_uri(self) -> str: normalized_resource_name = NamingConvention().normalize_identifier(self._resource_name) - # Appended __v2 on to the end of the url so that data of the V2 pipeline isn't the same as V1 - return f"{settings.BUCKET_URL}/{self._job.folder_path()}/{normalized_resource_name}__v2" + return f"{settings.BUCKET_URL}/{self._job.folder_path()}/{normalized_resource_name}" def _evolve_delta_schema(self, schema: pa.Schema) -> deltalake.DeltaTable: delta_table = self.get_delta_table() diff --git a/posthog/temporal/data_imports/pipelines/pipeline/pipeline.py b/posthog/temporal/data_imports/pipelines/pipeline/pipeline.py index c1a8b95bb0abe..e3356296a130d 100644 --- a/posthog/temporal/data_imports/pipelines/pipeline/pipeline.py +++ b/posthog/temporal/data_imports/pipelines/pipeline/pipeline.py @@ -8,11 +8,13 @@ import deltalake as deltalake from posthog.temporal.common.logger import FilteringBoundLogger from posthog.temporal.data_imports.pipelines.pipeline.utils import ( + _handle_null_columns_with_definitions, _update_incremental_state, _get_primary_keys, _evolve_pyarrow_schema, _append_debug_column_to_pyarrows_table, _update_job_row_count, + _update_last_synced_at_sync, table_from_py_list, ) from posthog.temporal.data_imports.pipelines.pipeline.delta_table_helper import DeltaTableHelper @@ -133,6 +135,7 @@ def _process_pa_table(self, pa_table: pa.Table, index: int): pa_table = _append_debug_column_to_pyarrows_table(pa_table, self._load_id) pa_table = _evolve_pyarrow_schema(pa_table, delta_table.schema() if delta_table is not None else None) + pa_table = _handle_null_columns_with_definitions(pa_table, self._resource) table_primary_keys = _get_primary_keys(self._resource) delta_table = self._delta_table_helper.write_to_deltalake( @@ -173,11 +176,14 @@ def _post_run_operations(self, row_count: int): process.kill() file_uris = delta_table.file_uris() - self._logger.info(f"Preparing S3 files - total parquet files: {len(file_uris)}") + self._logger.debug(f"Preparing S3 files - total parquet files: {len(file_uris)}") prepare_s3_files_for_querying( self._job.folder_path(), self._resource_name, file_uris, ExternalDataJob.PipelineVersion.V2 ) + self._logger.debug("Updating last synced at timestamp on schema") + _update_last_synced_at_sync(self._schema, self._job) + self._logger.debug("Validating schema and updating table") validate_schema_and_update_table_sync( diff --git a/posthog/temporal/data_imports/pipelines/pipeline/utils.py b/posthog/temporal/data_imports/pipelines/pipeline/utils.py index d0f3816c5fcd6..be0bb1064b5d2 100644 --- a/posthog/temporal/data_imports/pipelines/pipeline/utils.py +++ b/posthog/temporal/data_imports/pipelines/pipeline/utils.py @@ -8,8 +8,19 @@ import deltalake as deltalake from django.db.models import F from posthog.temporal.common.logger import FilteringBoundLogger +from dlt.common.data_types.typing import TDataType from posthog.warehouse.models import ExternalDataJob, ExternalDataSchema +DLT_TO_PA_TYPE_MAP = { + "text": pa.string(), + "bigint": pa.int64(), + "bool": pa.bool_(), + "timestamp": pa.timestamp("us"), + "json": pa.string(), + "double": pa.float64(), + "date": pa.date64(), +} + def _get_primary_keys(resource: DltResource) -> list[Any] | None: primary_keys = resource._hints.get("primary_key") @@ -29,6 +40,30 @@ def _get_primary_keys(resource: DltResource) -> list[Any] | None: raise Exception(f"primary_keys of type {primary_keys.__class__.__name__} are not supported") +def _get_column_hints(resource: DltResource) -> dict[str, TDataType] | None: + columns = resource._hints.get("columns") + + if columns is None: + return None + + return {key: value.get("data_type") for key, value in columns.items()} # type: ignore + + +def _handle_null_columns_with_definitions(table: pa.Table, resource: DltResource) -> pa.Table: + column_hints = _get_column_hints(resource) + + if column_hints is None: + return table + + for field_name, data_type in column_hints.items(): + # If the table doesn't have all fields, then add a field with all Nulls and the correct field type + if field_name not in table.schema.names: + new_column = pa.array([None] * table.num_rows, type=DLT_TO_PA_TYPE_MAP[data_type]) + table = table.append_column(field_name, new_column) + + return table + + def _evolve_pyarrow_schema(table: pa.Table, delta_schema: deltalake.Schema | None) -> pa.Table: py_table_field_names = table.schema.names @@ -128,6 +163,11 @@ def _update_incremental_state(schema: ExternalDataSchema | None, table: pa.Table schema.update_incremental_field_last_value(last_value) +def _update_last_synced_at_sync(schema: ExternalDataSchema, job: ExternalDataJob) -> None: + schema.last_synced_at = job.created_at + schema.save() + + def _update_job_row_count(job_id: str, count: int, logger: FilteringBoundLogger) -> None: logger.debug(f"Updating rows_synced with +{count}") ExternalDataJob.objects.filter(id=job_id).update(rows_synced=F("rows_synced") + count) diff --git a/posthog/temporal/data_imports/pipelines/pipeline_sync.py b/posthog/temporal/data_imports/pipelines/pipeline_sync.py index 314dad8a436e1..4df65ea0ae2e9 100644 --- a/posthog/temporal/data_imports/pipelines/pipeline_sync.py +++ b/posthog/temporal/data_imports/pipelines/pipeline_sync.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from datetime import datetime, date -from typing import Any, Literal, Optional -from collections.abc import Iterator, Sequence +from typing import Any, Optional import uuid import dlt @@ -13,45 +12,19 @@ import dlt.extract import dlt.extract.incremental import dlt.extract.incremental.transform -from dlt.pipeline.exceptions import PipelineStepFailed -from deltalake import DeltaTable import pendulum import pyarrow -from posthog.settings.base_variables import TEST -from structlog.typing import FilteringBoundLogger -from dlt.common.libs.deltalake import get_delta_tables from dlt.common.normalizers.naming.snake_case import NamingConvention from dlt.common.schema.typing import TSchemaTables -from dlt.load.exceptions import LoadClientJobRetry -from dlt.sources import DltSource -from dlt.destinations.impl.filesystem.filesystem import FilesystemClient -from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration -from dlt.common.destination.reference import ( - FollowupJobRequest, -) -from dlt.common.destination.typing import ( - PreparedTableSchema, -) -from dlt.destinations.job_impl import ( - ReferenceFollowupJobRequest, -) -from dlt.common.storages import FileStorage -from dlt.common.storages.load_package import ( - LoadJobInfo, -) -from deltalake.exceptions import DeltaError -from collections import Counter from clickhouse_driver.errors import ServerException from posthog.temporal.common.logger import bind_temporal_worker_logger_sync -from posthog.warehouse.data_load.validate_schema import dlt_to_hogql_type from posthog.warehouse.models.credential import get_or_create_datawarehouse_credential from posthog.warehouse.models.external_data_job import ExternalDataJob from posthog.warehouse.models.external_data_schema import ExternalDataSchema from posthog.warehouse.models.external_data_source import ExternalDataSource from posthog.warehouse.models.table import DataWarehouseTable -from posthog.temporal.data_imports.util import prepare_s3_files_for_querying def _from_arrow_scalar(arrow_value: pyarrow.Scalar) -> Any: @@ -79,340 +52,6 @@ class PipelineInputs: team_id: int -class DataImportPipelineSync: - loader_file_format: Literal["parquet"] = "parquet" - - def __init__( - self, - inputs: PipelineInputs, - source: DltSource, - logger: FilteringBoundLogger, - reset_pipeline: bool, - incremental: bool = False, - ): - self.inputs = inputs - self.logger = logger - - self._incremental = incremental - self.refresh_dlt = reset_pipeline - self.should_chunk_pipeline = ( - incremental - and inputs.job_type != ExternalDataSource.Type.POSTGRES - and inputs.job_type != ExternalDataSource.Type.MYSQL - and inputs.job_type != ExternalDataSource.Type.MSSQL - and inputs.job_type != ExternalDataSource.Type.SNOWFLAKE - and inputs.job_type != ExternalDataSource.Type.BIGQUERY - ) - - if self.should_chunk_pipeline: - # Incremental syncs: Assuming each page is 100 items for now so bound each run at 50_000 items - self.source = source.add_limit(500) - else: - self.source = source - - def _get_pipeline_name(self): - return f"{self.inputs.job_type}_pipeline_{self.inputs.team_id}_run_{self.inputs.schema_id}" - - def _get_credentials(self): - if TEST: - return { - "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY, - "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET, - "endpoint_url": settings.OBJECT_STORAGE_ENDPOINT, - "region_name": settings.AIRBYTE_BUCKET_REGION, - "AWS_ALLOW_HTTP": "true", - "AWS_S3_ALLOW_UNSAFE_RENAME": "true", - } - - return { - "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY, - "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET, - "region_name": settings.AIRBYTE_BUCKET_REGION, - "AWS_DEFAULT_REGION": settings.AIRBYTE_BUCKET_REGION, - "AWS_S3_ALLOW_UNSAFE_RENAME": "true", - } - - def _get_destination(self): - return dlt.destinations.filesystem( - credentials=self._get_credentials(), - bucket_url=settings.BUCKET_URL, # type: ignore - ) - - def _create_pipeline(self): - pipeline_name = self._get_pipeline_name() - destination = self._get_destination() - - def create_table_chain_completed_followup_jobs( - self: FilesystemClient, - table_chain: Sequence[PreparedTableSchema], - completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> list[FollowupJobRequest]: - assert completed_table_chain_jobs is not None - jobs = super(FilesystemClient, self).create_table_chain_completed_followup_jobs( - table_chain, completed_table_chain_jobs - ) - if table_chain[0].get("table_format") == "delta": - for table in table_chain: - table_job_paths = [ - job.file_path - for job in completed_table_chain_jobs - if job.job_file_info.table_name == table["name"] - ] - if len(table_job_paths) == 0: - # file_name = ParsedLoadJobFileName(table["name"], "empty", 0, "reference").file_name() - # TODO: if we implement removal od orphaned rows, we may need to propagate such job without files - # to the delta load job - pass - else: - files_per_job = self.config.delta_jobs_per_write or len(table_job_paths) - for i in range(0, len(table_job_paths), files_per_job): - jobs_chunk = table_job_paths[i : i + files_per_job] - file_name = FileStorage.get_file_name_from_file_path(jobs_chunk[0]) - jobs.append(ReferenceFollowupJobRequest(file_name, jobs_chunk)) - - return jobs - - def _iter_chunks(self, lst: list[Any], n: int) -> Iterator[list[Any]]: - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i : i + n] - - # Monkey patch to fix large memory consumption until https://github.com/dlt-hub/dlt/pull/2031 gets merged in - FilesystemDestinationClientConfiguration.delta_jobs_per_write = 1 - FilesystemClient.create_table_chain_completed_followup_jobs = create_table_chain_completed_followup_jobs # type: ignore - FilesystemClient._iter_chunks = _iter_chunks # type: ignore - - dlt.config["data_writer.file_max_items"] = 500_000 - dlt.config["data_writer.file_max_bytes"] = 500_000_000 # 500 MB - dlt.config["parallelism_strategy"] = "table-sequential" - dlt.config["delta_jobs_per_write"] = 1 - - dlt.config["normalize.parquet_normalizer.add_dlt_load_id"] = True - dlt.config["normalize.parquet_normalizer.add_dlt_id"] = True - - return dlt.pipeline( - pipeline_name=pipeline_name, destination=destination, dataset_name=self.inputs.dataset_name, progress="log" - ) - - def _prepare_s3_files_for_querying(self, file_uris: list[str]): - job = ExternalDataJob.objects.prefetch_related( - "pipeline", Prefetch("schema", queryset=ExternalDataSchema.objects.prefetch_related("source")) - ).get(pk=self.inputs.run_id) - - schema = ( - ExternalDataSchema.objects.prefetch_related("source") - .exclude(deleted=True) - .get(id=self.inputs.schema_id, team_id=self.inputs.team_id) - ) - - prepare_s3_files_for_querying(job.folder_path(), schema.name, file_uris) - - def _get_delta_table(self, resouce_name: str) -> DeltaTable | None: - normalized_schema_name = NamingConvention().normalize_identifier(resouce_name) - delta_uri = f"{settings.BUCKET_URL}/{self.inputs.dataset_name}/{normalized_schema_name}" - storage_options = self._get_credentials() - - self.logger.debug(f"delta_uri={delta_uri}") - - is_delta_table = DeltaTable.is_deltatable(delta_uri, storage_options) - - self.logger.debug(f"is_delta_table={is_delta_table}") - - if is_delta_table: - return DeltaTable(delta_uri, storage_options=storage_options) - - return None - - def _run(self) -> dict[str, int]: - if self.refresh_dlt: - self.logger.info("Pipeline getting a full refresh due to reset_pipeline being set") - - pipeline = self._create_pipeline() - - # Workaround for full refresh schemas while we wait for Rust to fix memory issue - for name, resource in self.source._resources.items(): - if resource.write_disposition == "replace": - delta_table = self._get_delta_table(name) - - if delta_table is not None: - self.logger.debug("Deleting existing delta table") - delta_table.delete() - - self.logger.debug("Updating table write_disposition to append") - resource.apply_hints(write_disposition="append") - - total_counts: Counter[str] = Counter({}) - - # Do chunking for incremental syncing on API based endpoints (e.g. not sql databases) - if self.should_chunk_pipeline: - # will get overwritten - counts: Counter[str] = Counter({"start": 1}) - pipeline_runs = 0 - - while counts: - self.logger.info(f"Running incremental (non-sql) pipeline, run ${pipeline_runs}") - - try: - pipeline.run( - self.source, - loader_file_format=self.loader_file_format, - refresh="drop_sources" if self.refresh_dlt and pipeline_runs == 0 else None, - ) - except PipelineStepFailed as e: - # Remove once DLT support writing empty Delta files - if isinstance(e.exception, LoadClientJobRetry): - if "Generic S3 error" not in e.exception.retry_message: - raise - elif isinstance(e.exception, DeltaError): - if e.exception.args[0] != "Generic error: No data source supplied to write command.": - raise - else: - raise - - if pipeline.last_trace.last_normalize_info is not None: - row_counts = pipeline.last_trace.last_normalize_info.row_counts - else: - row_counts = {} - # Remove any DLT tables from the counts - filtered_rows = dict(filter(lambda pair: not pair[0].startswith("_dlt"), row_counts.items())) - counts = Counter(filtered_rows) - total_counts = counts + total_counts - - if total_counts.total() > 0: - # Fix to upgrade all tables to DeltaS3Wrapper - resouce_names = list(self.source._resources.keys()) - if len(resouce_names) > 0: - name = resouce_names[0] - table = self._get_delta_table(name) - if table is not None: - delta_tables = {name: table} - else: - delta_tables = get_delta_tables(pipeline) - else: - delta_tables = get_delta_tables(pipeline) - - table_format = DataWarehouseTable.TableFormat.DeltaS3Wrapper - - # Workaround while we fix msising table_format on DLT resource - if len(delta_tables.values()) == 0: - table_format = DataWarehouseTable.TableFormat.Delta - - # There should only ever be one table here - for table in delta_tables.values(): - self.logger.info("Compacting delta table") - table.optimize.compact() - table.vacuum(retention_hours=24, enforce_retention_duration=False, dry_run=False) - - file_uris = table.file_uris() - self.logger.info(f"Preparing S3 files - total parquet files: {len(file_uris)}") - self._prepare_s3_files_for_querying(file_uris) - - self.logger.info(f"Table format: {table_format}") - - validate_schema_and_update_table_sync( - run_id=self.inputs.run_id, - team_id=self.inputs.team_id, - schema_id=self.inputs.schema_id, - table_schema=self.source.schema.tables, - row_count=total_counts.total(), - table_format=table_format, - ) - else: - self.logger.info("No table_counts, skipping validate_schema_and_update_table") - - pipeline_runs = pipeline_runs + 1 - else: - self.logger.info("Running standard pipeline") - try: - pipeline.run( - self.source, - loader_file_format=self.loader_file_format, - refresh="drop_sources" if self.refresh_dlt else None, - ) - except PipelineStepFailed as e: - # Remove once DLT support writing empty Delta files - if isinstance(e.exception, LoadClientJobRetry): - if "Generic S3 error" not in e.exception.retry_message: - raise - elif isinstance(e.exception, DeltaError): - if e.exception.args[0] != "Generic error: No data source supplied to write command.": - raise - else: - raise - - if pipeline.last_trace.last_normalize_info is not None: - row_counts = pipeline.last_trace.last_normalize_info.row_counts - else: - row_counts = {} - - filtered_rows = dict(filter(lambda pair: not pair[0].startswith("_dlt"), row_counts.items())) - counts = Counter(filtered_rows) - total_counts = total_counts + counts - - if total_counts.total() > 0: - # Fix to upgrade all tables to DeltaS3Wrapper - resouce_names = list(self.source._resources.keys()) - if len(resouce_names) > 0: - name = resouce_names[0] - table = self._get_delta_table(name) - if table is not None: - delta_tables = {name: table} - else: - delta_tables = get_delta_tables(pipeline) - else: - delta_tables = get_delta_tables(pipeline) - - table_format = DataWarehouseTable.TableFormat.DeltaS3Wrapper - - # Workaround while we fix msising table_format on DLT resource - if len(delta_tables.values()) == 0: - table_format = DataWarehouseTable.TableFormat.Delta - - # There should only ever be one table here - for table in delta_tables.values(): - self.logger.info("Compacting delta table") - table.optimize.compact() - table.vacuum(retention_hours=24, enforce_retention_duration=False, dry_run=False) - - file_uris = table.file_uris() - self.logger.info(f"Preparing S3 files - total parquet files: {len(file_uris)}") - self._prepare_s3_files_for_querying(file_uris) - - self.logger.info(f"Table format: {table_format}") - - validate_schema_and_update_table_sync( - run_id=self.inputs.run_id, - team_id=self.inputs.team_id, - schema_id=self.inputs.schema_id, - table_schema=self.source.schema.tables, - row_count=total_counts.total(), - table_format=table_format, - ) - else: - self.logger.info("No table_counts, skipping validate_schema_and_update_table") - - # Update last_synced_at on schema - update_last_synced_at_sync( - job_id=self.inputs.run_id, schema_id=str(self.inputs.schema_id), team_id=self.inputs.team_id - ) - - if self._incremental: - self.logger.debug("Saving last incremental value...") - save_last_incremental_value(str(self.inputs.schema_id), str(self.inputs.team_id), self.source, self.logger) - - # Cleanup: delete local state from the file system - pipeline.drop() - - return dict(total_counts) - - def run(self) -> dict[str, int]: - try: - return self._run() - except PipelineStepFailed as e: - self.logger.exception(f"Data import failed for endpoint with exception {e}", exc_info=e) - raise - - def update_last_synced_at_sync(job_id: str, schema_id: str, team_id: int) -> None: job = ExternalDataJob.objects.prefetch_related( "pipeline", Prefetch("schema", queryset=ExternalDataSchema.objects.prefetch_related("source")) @@ -426,34 +65,6 @@ def update_last_synced_at_sync(job_id: str, schema_id: str, team_id: int) -> Non schema.save() -def save_last_incremental_value(schema_id: str, team_id: str, source: DltSource, logger: FilteringBoundLogger) -> None: - schema = ExternalDataSchema.objects.exclude(deleted=True).get(id=schema_id, team_id=team_id) - - incremental_field = schema.sync_type_config.get("incremental_field") - resource = next(iter(source.resources.values())) - - incremental: dict | None = resource.state.get("incremental") - - if incremental is None: - return - - incremental_object: dict | None = incremental.get(incremental_field) - if incremental_object is None: - return - - last_value = incremental_object.get("last_value") - - logger.debug(f"Updating incremental_field_last_value with {last_value}") - - if last_value is None: - logger.debug( - f"Incremental value is None. This could mean the table has zero rows. Full incremental object: {incremental_object}" - ) - return - - schema.update_incremental_field_last_value(last_value) - - def validate_schema_and_update_table_sync( run_id: str, team_id: int, @@ -486,18 +97,6 @@ def validate_schema_and_update_table_sync( "pipeline", Prefetch("schema", queryset=ExternalDataSchema.objects.prefetch_related("source")) ).get(pk=run_id) - using_v2_pipeline = job.pipeline_version == ExternalDataJob.PipelineVersion.V2 - pipeline_version = ( - ExternalDataJob.PipelineVersion.V1 - if job.pipeline_version is None - else ExternalDataJob.PipelineVersion(job.pipeline_version) - ) - - # Temp so we dont create a bunch of orphaned Table objects - if using_v2_pipeline: - logger.debug("Using V2 pipeline - dont create table object or get columns") - return - credential = get_or_create_datawarehouse_credential( team_id=team_id, access_key=settings.AIRBYTE_BUCKET_KEY, @@ -546,65 +145,21 @@ def validate_schema_and_update_table_sync( assert isinstance(table_created, DataWarehouseTable) and table_created is not None - # Temp fix #2 for Delta tables without table_format - if not using_v2_pipeline: - try: - table_created.get_columns() - except Exception as e: - if table_format == DataWarehouseTable.TableFormat.DeltaS3Wrapper: - logger.exception( - "get_columns exception with DeltaS3Wrapper format - trying Delta format", exc_info=e - ) - - table_created.format = DataWarehouseTable.TableFormat.Delta - table_created.get_columns() - table_created.save() - - logger.info("Delta format worked - updating table to use Delta") - else: - raise - - # If using new non-DLT pipeline - if using_v2_pipeline and table_schema_dict is not None: - raw_db_columns: dict[str, dict[str, str]] = table_created.get_columns(pipeline_version=pipeline_version) - db_columns = {key: column.get("clickhouse", "") for key, column in raw_db_columns.items()} - - columns = {} - for column_name, db_column_type in db_columns.items(): - hogql_type = table_schema_dict.get(column_name) - - if hogql_type is None: - raise Exception(f"HogQL type not found for column: {column_name}") - - columns[column_name] = { - "clickhouse": db_column_type, - "hogql": hogql_type, - } - table_created.columns = columns - else: - # If using DLT pipeline - for schema in table_schema.values(): - if schema.get("resource") == _schema_name: - schema_columns = schema.get("columns") or {} - raw_db_columns: dict[str, dict[str, str]] = table_created.get_columns() - db_columns = {key: column.get("clickhouse", "") for key, column in raw_db_columns.items()} - - columns = {} - for column_name, db_column_type in db_columns.items(): - dlt_column = schema_columns.get(column_name) - if dlt_column is not None: - dlt_data_type = dlt_column.get("data_type") - hogql_type = dlt_to_hogql_type(dlt_data_type) - else: - hogql_type = dlt_to_hogql_type(None) - - columns[column_name] = { - "clickhouse": db_column_type, - "hogql": hogql_type, - } - table_created.columns = columns - break + raw_db_columns: dict[str, dict[str, str]] = table_created.get_columns() + db_columns = {key: column.get("clickhouse", "") for key, column in raw_db_columns.items()} + columns = {} + for column_name, db_column_type in db_columns.items(): + hogql_type = table_schema_dict.get(column_name) + + if hogql_type is None: + raise Exception(f"HogQL type not found for column: {column_name}") + + columns[column_name] = { + "clickhouse": db_column_type, + "hogql": hogql_type, + } + table_created.columns = columns table_created.save() # schema could have been deleted by this point @@ -614,9 +169,8 @@ def validate_schema_and_update_table_sync( .get(id=_schema_id, team_id=team_id) ) - if not using_v2_pipeline and schema_model: - schema_model.table = table_created - schema_model.save() + schema_model.table = table_created + schema_model.save() except ServerException as err: if err.code == 636: diff --git a/posthog/temporal/data_imports/pipelines/rest_source/__init__.py b/posthog/temporal/data_imports/pipelines/rest_source/__init__.py index 9a8599882c652..044ade4c93033 100644 --- a/posthog/temporal/data_imports/pipelines/rest_source/__init__.py +++ b/posthog/temporal/data_imports/pipelines/rest_source/__init__.py @@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator, Iterator from collections.abc import Callable import graphlib # type: ignore[import,unused-ignore] +from dateutil import parser import dlt from dlt.common.validation import validate_dict @@ -29,6 +30,8 @@ Endpoint, EndpointResource, RESTAPIConfig, + TAnySchemaColumns, + TTableHintTemplate, ) from .config_setup import ( IncrementalParam, @@ -42,6 +45,27 @@ from .utils import exclude_keys # noqa: F401 +def convert_types( + data: Iterator[Any] | list[Any], types: Optional[dict[str, dict[str, Any]]] +) -> Iterator[dict[str, Any]]: + if types is None: + yield from data + return + + for item in data: + for key, column in types.items(): + data_type = column.get("data_type") + + if key in item: + current_value = item.get(key) + if data_type == "timestamp" and isinstance(current_value, str): + item[key] = parser.parse(current_value) + elif data_type == "date" and isinstance(current_value, str): + item[key] = parser.parse(current_value).date() + + yield item + + def rest_api_source( config: RESTAPIConfig, team_id: int, @@ -246,6 +270,8 @@ def create_resources( resource_kwargs = exclude_keys(endpoint_resource, {"endpoint", "include_from_parent"}) + columns_config = endpoint_resource.get("columns") + if resolved_param is None: async def paginate_resource( @@ -257,6 +283,7 @@ async def paginate_resource( data_selector: Optional[jsonpath.TJsonPath], hooks: Optional[dict[str, Any]], client: RESTClient = client, + columns_config: Optional[TTableHintTemplate[TAnySchemaColumns]] = None, incremental_object: Optional[Incremental[Any]] = incremental_object, incremental_param: Optional[IncrementalParam] = incremental_param, incremental_cursor_transform: Optional[Callable[..., Any]] = incremental_cursor_transform, @@ -272,14 +299,17 @@ async def paginate_resource( db_incremental_field_last_value, ) - yield client.paginate( - method=method, - path=path, - params=params, - json=json, - paginator=paginator, - data_selector=data_selector, - hooks=hooks, + yield convert_types( + client.paginate( + method=method, + path=path, + params=params, + json=json, + paginator=paginator, + data_selector=data_selector, + hooks=hooks, + ), + columns_config, ) resources[resource_name] = dlt.resource( @@ -293,6 +323,7 @@ async def paginate_resource( paginator=paginator, data_selector=endpoint_config.get("data_selector"), hooks=hooks, + columns_config=columns_config, ) else: @@ -311,6 +342,7 @@ async def paginate_dependent_resource( client: RESTClient = client, resolved_param: ResolvedParam = resolved_param, include_from_parent: list[str] = include_from_parent, + columns_config: Optional[TTableHintTemplate[TAnySchemaColumns]] = None, incremental_object: Optional[Incremental[Any]] = incremental_object, incremental_param: Optional[IncrementalParam] = incremental_param, incremental_cursor_transform: Optional[Callable[..., Any]] = incremental_cursor_transform, @@ -342,7 +374,8 @@ async def paginate_dependent_resource( if parent_record: for child_record in child_page: child_record.update(parent_record) - yield child_page + + yield convert_types(child_page, columns_config) resources[resource_name] = dlt.resource( # type: ignore[call-overload] paginate_dependent_resource, @@ -355,6 +388,7 @@ async def paginate_dependent_resource( paginator=paginator, data_selector=endpoint_config.get("data_selector"), hooks=hooks, + columns_config=columns_config, ) return resources diff --git a/posthog/temporal/data_imports/pipelines/test/test_pipeline_sync.py b/posthog/temporal/data_imports/pipelines/test/test_pipeline_sync.py deleted file mode 100644 index 5b765e35cea14..0000000000000 --- a/posthog/temporal/data_imports/pipelines/test/test_pipeline_sync.py +++ /dev/null @@ -1,152 +0,0 @@ -import functools -from typing import Any -from unittest.mock import MagicMock, PropertyMock, patch -import uuid - -import boto3 -import pytest -import structlog -from django.conf import settings -from django.test import override_settings -from posthog.temporal.data_imports.pipelines.pipeline_sync import DataImportPipelineSync, PipelineInputs -from posthog.temporal.data_imports.pipelines.stripe import stripe_source -from posthog.test.base import APIBaseTest -from posthog.warehouse.models.external_data_job import ExternalDataJob -from posthog.warehouse.models.external_data_schema import ExternalDataSchema -from posthog.warehouse.models.external_data_source import ExternalDataSource - - -BUCKET_NAME = "test-pipeline-sync" -SESSION = boto3.Session() -create_test_client = functools.partial(SESSION.client, endpoint_url=settings.OBJECT_STORAGE_ENDPOINT) - - -class TestDataImportPipeline(APIBaseTest): - @pytest.fixture(autouse=True) - def minio_client(self): - """Manage an S3 client to interact with a MinIO bucket. - - Yields the client after creating a bucket. Upon resuming, we delete - the contents and the bucket itself. - """ - minio_client = create_test_client( - "s3", - aws_access_key_id=settings.OBJECT_STORAGE_ACCESS_KEY_ID, - aws_secret_access_key=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, - ) - - try: - minio_client.head_bucket(Bucket=BUCKET_NAME) - except: - minio_client.create_bucket(Bucket=BUCKET_NAME) - - yield minio_client - - def _create_pipeline(self, schema_name: str, incremental: bool): - source = ExternalDataSource.objects.create( - source_id=str(uuid.uuid4()), - connection_id=str(uuid.uuid4()), - destination_id=str(uuid.uuid4()), - team=self.team, - status="running", - source_type="Stripe", - ) - schema = ExternalDataSchema.objects.create( - name=schema_name, - team_id=self.team.pk, - source_id=source.pk, - source=source, - ) - job = ExternalDataJob.objects.create( - team_id=self.team.pk, - pipeline_id=source.pk, - pipeline=source, - schema_id=schema.pk, - schema=schema, - status=ExternalDataJob.Status.RUNNING, - rows_synced=0, - workflow_id=str(uuid.uuid4()), - pipeline_version=ExternalDataJob.PipelineVersion.V1, - ) - - pipeline = DataImportPipelineSync( - inputs=PipelineInputs( - source_id=source.pk, - run_id=str(job.pk), - schema_id=schema.pk, - dataset_name=job.folder_path(), - job_type=ExternalDataSource.Type.STRIPE, - team_id=self.team.pk, - ), - source=stripe_source( - api_key="", - account_id="", - endpoint=schema_name, - is_incremental=incremental, - team_id=self.team.pk, - job_id=str(job.pk), - db_incremental_field_last_value=0, - ), - logger=structlog.get_logger(), - incremental=incremental, - reset_pipeline=False, - ) - - return pipeline - - @pytest.mark.django_db(transaction=True) - def test_pipeline_non_incremental(self): - def mock_create_pipeline(local_self: Any): - mock = MagicMock() - mock.last_trace.last_normalize_info.row_counts = {"customer": 1} - return mock - - with ( - patch.object(DataImportPipelineSync, "_create_pipeline", mock_create_pipeline), - patch( - "posthog.temporal.data_imports.pipelines.pipeline_sync.validate_schema_and_update_table_sync" - ) as mock_validate_schema_and_update_table, - patch("posthog.temporal.data_imports.pipelines.pipeline_sync.get_delta_tables"), - patch("posthog.temporal.data_imports.pipelines.pipeline_sync.update_last_synced_at_sync"), - override_settings( - BUCKET_URL=f"s3://{BUCKET_NAME}", - AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, - AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, - AIRBYTE_BUCKET_REGION="us-east-1", - AIRBYTE_BUCKET_DOMAIN="objectstorage:19000", - ), - ): - pipeline = self._create_pipeline("Customer", False) - res = pipeline.run() - - assert res.get("customer") == 1 - assert mock_validate_schema_and_update_table.call_count == 1 - - @pytest.mark.django_db(transaction=True) - def test_pipeline_incremental(self): - def mock_create_pipeline(local_self: Any): - mock = MagicMock() - type(mock.last_trace.last_normalize_info).row_counts = PropertyMock(side_effect=[{"customer": 1}, {}]) - return mock - - with ( - patch.object(DataImportPipelineSync, "_create_pipeline", mock_create_pipeline), - patch( - "posthog.temporal.data_imports.pipelines.pipeline_sync.validate_schema_and_update_table_sync" - ) as mock_validate_schema_and_update_table, - patch("posthog.temporal.data_imports.pipelines.pipeline_sync.get_delta_tables"), - patch("posthog.temporal.data_imports.pipelines.pipeline_sync.update_last_synced_at_sync"), - patch("posthog.temporal.data_imports.pipelines.pipeline_sync.save_last_incremental_value"), - override_settings( - BUCKET_URL=f"s3://{BUCKET_NAME}", - AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, - AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, - AIRBYTE_BUCKET_REGION="us-east-1", - AIRBYTE_BUCKET_DOMAIN="objectstorage:19000", - ), - ): - pipeline = self._create_pipeline("Customer", True) - res = pipeline.run() - - assert res.get("customer") == 1 - assert mock_validate_schema_and_update_table.call_count == 2 diff --git a/posthog/temporal/data_imports/util.py b/posthog/temporal/data_imports/util.py index 1900c237c1c2a..bfad031b2669d 100644 --- a/posthog/temporal/data_imports/util.py +++ b/posthog/temporal/data_imports/util.py @@ -19,15 +19,9 @@ def prepare_s3_files_for_querying( s3_folder_for_job = f"{settings.BUCKET_URL}/{folder_path}" - if pipeline_version == ExternalDataJob.PipelineVersion.V2: - s3_folder_for_schema = f"{s3_folder_for_job}/{normalized_table_name}__v2" - else: - s3_folder_for_schema = f"{s3_folder_for_job}/{normalized_table_name}" - - if pipeline_version == ExternalDataJob.PipelineVersion.V2: - s3_folder_for_querying = f"{s3_folder_for_job}/{normalized_table_name}__query_v2" - else: - s3_folder_for_querying = f"{s3_folder_for_job}/{normalized_table_name}__query" + s3_folder_for_schema = f"{s3_folder_for_job}/{normalized_table_name}" + + s3_folder_for_querying = f"{s3_folder_for_job}/{normalized_table_name}__query" if s3.exists(s3_folder_for_querying): s3.delete(s3_folder_for_querying, recursive=True) diff --git a/posthog/temporal/data_imports/workflow_activities/create_job_model.py b/posthog/temporal/data_imports/workflow_activities/create_job_model.py index 729ff9960ef44..6b98a3b221d6d 100644 --- a/posthog/temporal/data_imports/workflow_activities/create_job_model.py +++ b/posthog/temporal/data_imports/workflow_activities/create_job_model.py @@ -1,13 +1,11 @@ import dataclasses import uuid -from django.conf import settings from django.db import close_old_connections from temporalio import activity # TODO: remove dependency -from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE_V2 from posthog.warehouse.data_load.service import delete_external_data_schedule from posthog.warehouse.models import ExternalDataJob, ExternalDataSource from posthog.warehouse.models.external_data_schema import ( @@ -23,13 +21,6 @@ class CreateExternalDataJobModelActivityInputs: source_id: uuid.UUID -def get_pipeline_version() -> str: - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE_V2: - return ExternalDataJob.PipelineVersion.V2 - - return ExternalDataJob.PipelineVersion.V1 - - @activity.defn def create_external_data_job_model_activity( inputs: CreateExternalDataJobModelActivityInputs, @@ -46,8 +37,6 @@ def create_external_data_job_model_activity( delete_external_data_schedule(str(inputs.schema_id)) raise Exception("Source or schema no longer exists - deleted temporal schedule") - pipeline_version = get_pipeline_version() - job = ExternalDataJob.objects.create( team_id=inputs.team_id, pipeline_id=inputs.source_id, @@ -56,8 +45,8 @@ def create_external_data_job_model_activity( rows_synced=0, workflow_id=activity.info().workflow_id, workflow_run_id=activity.info().workflow_run_id, - pipeline_version=pipeline_version, - billable=pipeline_version != ExternalDataJob.PipelineVersion.V2, + pipeline_version=ExternalDataJob.PipelineVersion.V2, + billable=True, ) schema = ExternalDataSchema.objects.get(team_id=inputs.team_id, id=inputs.schema_id) diff --git a/posthog/temporal/data_imports/workflow_activities/import_data_sync.py b/posthog/temporal/data_imports/workflow_activities/import_data_sync.py index a9a058bb52261..db09728acce55 100644 --- a/posthog/temporal/data_imports/workflow_activities/import_data_sync.py +++ b/posthog/temporal/data_imports/workflow_activities/import_data_sync.py @@ -4,20 +4,17 @@ from dateutil import parser from typing import Any -from django.conf import settings from django.db import close_old_connections -from django.db.models import Prefetch, F +from django.db.models import Prefetch from temporalio import activity -from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE_V2 from posthog.models.integration import Integration from posthog.temporal.common.heartbeat_sync import HeartbeaterSync from posthog.temporal.data_imports.pipelines.bigquery import delete_all_temp_destination_tables, delete_table from posthog.temporal.data_imports.pipelines.pipeline.pipeline import PipelineNonDLT -from posthog.temporal.data_imports.pipelines.pipeline_sync import DataImportPipelineSync, PipelineInputs -from posthog.temporal.data_imports.util import is_posthog_team, is_enabled_for_team +from posthog.temporal.data_imports.pipelines.pipeline_sync import PipelineInputs from posthog.warehouse.models import ( ExternalDataJob, ExternalDataSource, @@ -100,23 +97,10 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): endpoints = [schema.name] - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE_V2: - # Get the V2 last value, if it's not set yet (e.g. the first run), then fallback to the V1 value - processed_incremental_last_value = process_incremental_last_value( - schema.sync_type_config.get("incremental_field_last_value_v2"), - schema.sync_type_config.get("incremental_field_type"), - ) - - if processed_incremental_last_value is None: - processed_incremental_last_value = process_incremental_last_value( - schema.sync_type_config.get("incremental_field_last_value"), - schema.sync_type_config.get("incremental_field_type"), - ) - else: - processed_incremental_last_value = process_incremental_last_value( - schema.sync_type_config.get("incremental_field_last_value"), - schema.sync_type_config.get("incremental_field_type"), - ) + processed_incremental_last_value = processed_incremental_last_value = process_incremental_last_value( + schema.sync_type_config.get("incremental_field_last_value"), + schema.sync_type_config.get("incremental_field_type"), + ) if schema.is_incremental: logger.debug(f"Incremental last value being used is: {processed_incremental_last_value}") @@ -181,14 +165,7 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): ExternalDataSource.Type.MYSQL, ExternalDataSource.Type.MSSQL, ]: - if ( - is_posthog_team(inputs.team_id) - or is_enabled_for_team(inputs.team_id) - or settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE_V2 - ): - from posthog.temporal.data_imports.pipelines.sql_database_v2 import sql_source_for_type - else: - from posthog.temporal.data_imports.pipelines.sql_database import sql_source_for_type + from posthog.temporal.data_imports.pipelines.sql_database_v2 import sql_source_for_type host = model.pipeline.job_inputs.get("host") port = model.pipeline.job_inputs.get("port") @@ -284,14 +261,9 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): reset_pipeline=reset_pipeline, ) elif model.pipeline.source_type == ExternalDataSource.Type.SNOWFLAKE: - if is_posthog_team(inputs.team_id): - from posthog.temporal.data_imports.pipelines.sql_database_v2 import ( - snowflake_source, - ) - else: - from posthog.temporal.data_imports.pipelines.sql_database import ( - snowflake_source, - ) + from posthog.temporal.data_imports.pipelines.sql_database_v2 import ( + snowflake_source, + ) account_id = model.pipeline.job_inputs.get("account_id") database = model.pipeline.job_inputs.get("database") @@ -527,19 +499,9 @@ def _run( schema: ExternalDataSchema, reset_pipeline: bool, ): - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE_V2: - pipeline = PipelineNonDLT(source, logger, job_inputs.run_id, schema.is_incremental, reset_pipeline) - pipeline.run() - del pipeline - else: - table_row_counts = DataImportPipelineSync( - job_inputs, source, logger, reset_pipeline, schema.is_incremental - ).run() - total_rows_synced = sum(table_row_counts.values()) - - ExternalDataJob.objects.filter(id=inputs.run_id, team_id=inputs.team_id).update( - rows_synced=F("rows_synced") + total_rows_synced - ) + pipeline = PipelineNonDLT(source, logger, job_inputs.run_id, schema.is_incremental, reset_pipeline) + pipeline.run() + del pipeline source = ExternalDataSource.objects.get(id=inputs.source_id) source.job_inputs.pop("reset_pipeline", None) diff --git a/posthog/temporal/tests/data_imports/test_end_to_end.py b/posthog/temporal/tests/data_imports/test_end_to_end.py index 5bebcd72a5f57..381aa025817d0 100644 --- a/posthog/temporal/tests/data_imports/test_end_to_end.py +++ b/posthog/temporal/tests/data_imports/test_end_to_end.py @@ -21,7 +21,7 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import UnsandboxedWorkflowRunner, Worker -from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE, DATA_WAREHOUSE_TASK_QUEUE_V2 +from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE from posthog.hogql.modifiers import create_default_modifiers_for_team from posthog.hogql.query import execute_hogql_query from posthog.hogql_queries.insights.funnels.funnel import Funnel @@ -101,19 +101,6 @@ async def minio_client(): yield minio_client -def pytest_generate_tests(metafunc): - if "task_queue" in metafunc.fixturenames: - metafunc.parametrize("task_queue", [DATA_WAREHOUSE_TASK_QUEUE, DATA_WAREHOUSE_TASK_QUEUE_V2], indirect=True) - - -@pytest.fixture(autouse=True) -def task_queue(request): - queue = getattr(request, "param", None) - - with override_settings(TEMPORAL_TASK_QUEUE=queue): - yield - - async def _run( team: Team, schema_name: str, @@ -158,22 +145,18 @@ async def _run( await sync_to_async(schema.refresh_from_db)() - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE: - assert schema.last_synced_at == run.created_at - else: - assert schema.last_synced_at is None + assert schema.last_synced_at == run.created_at - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE: - res = await sync_to_async(execute_hogql_query)(f"SELECT * FROM {table_name}", team) - assert len(res.results) == 1 + res = await sync_to_async(execute_hogql_query)(f"SELECT * FROM {table_name}", team) + assert len(res.results) == 1 - for name, field in external_tables.get(table_name, {}).items(): - if field.hidden: - continue - assert name in (res.columns or []) + for name, field in external_tables.get(table_name, {}).items(): + if field.hidden: + continue + assert name in (res.columns or []) - await sync_to_async(source.refresh_from_db)() - assert source.job_inputs.get("reset_pipeline", None) is None + await sync_to_async(source.refresh_from_db)() + assert source.job_inputs.get("reset_pipeline", None) is None return workflow_id, inputs @@ -234,12 +217,11 @@ def mock_to_object_store_rs_credentials(class_self): ), mock.patch.object(AwsCredentials, "to_session_credentials", mock_to_session_credentials), mock.patch.object(AwsCredentials, "to_object_store_rs_credentials", mock_to_object_store_rs_credentials), - mock.patch("posthog.temporal.data_imports.external_data_job.trigger_pipeline_v2"), ): async with await WorkflowEnvironment.start_time_skipping() as activity_environment: async with Worker( activity_environment.client, - task_queue=settings.TEMPORAL_TASK_QUEUE, + task_queue=DATA_WAREHOUSE_TASK_QUEUE, workflows=[ExternalDataJobWorkflow], activities=ACTIVITIES, # type: ignore workflow_runner=UnsandboxedWorkflowRunner(), @@ -250,7 +232,7 @@ def mock_to_object_store_rs_credentials(class_self): ExternalDataJobWorkflow.run, inputs, id=workflow_id, - task_queue=settings.TEMPORAL_TASK_QUEUE, + task_queue=DATA_WAREHOUSE_TASK_QUEUE, retry_policy=RetryPolicy(maximum_attempts=1), ) @@ -557,13 +539,13 @@ async def test_postgres_binary_columns(team, postgres_config, postgres_connectio mock_data_response=[], ) - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE: - res = await sync_to_async(execute_hogql_query)(f"SELECT * FROM postgres_binary_col_test", team) - columns = res.columns + res = await sync_to_async(execute_hogql_query)(f"SELECT * FROM postgres_binary_col_test", team) + columns = res.columns - assert columns is not None - assert len(columns) == 1 - assert columns[0] == "id" + assert columns is not None + assert len(columns) == 2 + assert any(x == "_ph_debug" for x in columns) + assert any(x == "id" for x in columns) @pytest.mark.django_db(transaction=True) @@ -591,14 +573,9 @@ def get_jobs(): latest_job = jobs[0] folder_path = await sync_to_async(latest_job.folder_path)() - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE: - s3_objects = await minio_client.list_objects_v2( - Bucket=BUCKET_NAME, Prefix=f"{folder_path}/balance_transaction__query/" - ) - else: - s3_objects = await minio_client.list_objects_v2( - Bucket=BUCKET_NAME, Prefix=f"{folder_path}/balance_transaction__query_v2/" - ) + s3_objects = await minio_client.list_objects_v2( + Bucket=BUCKET_NAME, Prefix=f"{folder_path}/balance_transaction__query/" + ) assert len(s3_objects["Contents"]) != 0 @@ -625,24 +602,23 @@ async def test_funnels_lazy_joins_ordering(team, stripe_customer): field_name="stripe_customer", ) - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE: - query = FunnelsQuery( - series=[EventsNode(), EventsNode()], - breakdownFilter=BreakdownFilter( - breakdown_type=BreakdownType.DATA_WAREHOUSE_PERSON_PROPERTY, breakdown="stripe_customer.email" - ), - ) - funnel_class = Funnel(context=FunnelQueryContext(query=query, team=team)) + query = FunnelsQuery( + series=[EventsNode(), EventsNode()], + breakdownFilter=BreakdownFilter( + breakdown_type=BreakdownType.DATA_WAREHOUSE_PERSON_PROPERTY, breakdown="stripe_customer.email" + ), + ) + funnel_class = Funnel(context=FunnelQueryContext(query=query, team=team)) - query_ast = funnel_class.get_query() - await sync_to_async(execute_hogql_query)( - query_type="FunnelsQuery", - query=query_ast, - team=team, - modifiers=create_default_modifiers_for_team( - team, HogQLQueryModifiers(personsOnEventsMode=PersonsOnEventsMode.PERSON_ID_OVERRIDE_PROPERTIES_JOINED) - ), - ) + query_ast = funnel_class.get_query() + await sync_to_async(execute_hogql_query)( + query_type="FunnelsQuery", + query=query_ast, + team=team, + modifiers=create_default_modifiers_for_team( + team, HogQLQueryModifiers(personsOnEventsMode=PersonsOnEventsMode.PERSON_ID_OVERRIDE_PROPERTIES_JOINED) + ), + ) @pytest.mark.django_db(transaction=True) @@ -675,13 +651,12 @@ async def test_postgres_schema_evolution(team, postgres_config, postgres_connect sync_type_config={"incremental_field": "id", "incremental_field_type": "integer"}, ) - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE: - res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team) - columns = res.columns + res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team) + columns = res.columns - assert columns is not None - assert len(columns) == 1 - assert any(x == "id" for x in columns) + assert columns is not None + assert len(columns) == 2 + assert any(x == "id" for x in columns) # Evole schema await postgres_connection.execute( @@ -695,14 +670,13 @@ async def test_postgres_schema_evolution(team, postgres_config, postgres_connect # Execute the same schema again - load await _execute_run(str(uuid.uuid4()), inputs, []) - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE: - res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team) - columns = res.columns + res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team) + columns = res.columns - assert columns is not None - assert len(columns) == 2 - assert any(x == "id" for x in columns) - assert any(x == "new_col" for x in columns) + assert columns is not None + assert len(columns) == 3 + assert any(x == "id" for x in columns) + assert any(x == "new_col" for x in columns) @pytest.mark.django_db(transaction=True) @@ -739,16 +713,15 @@ async def test_sql_database_missing_incremental_values(team, postgres_config, po sync_type_config={"incremental_field": "id", "incremental_field_type": "integer"}, ) - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE: - res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team) - columns = res.columns + res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team) + columns = res.columns - assert columns is not None - assert len(columns) == 1 - assert any(x == "id" for x in columns) + assert columns is not None + assert len(columns) == 2 + assert any(x == "id" for x in columns) - # Exclude rows that don't have the incremental cursor key set - assert len(res.results) == 1 + # Exclude rows that don't have the incremental cursor key set + assert len(res.results) == 1 @pytest.mark.django_db(transaction=True) @@ -782,16 +755,15 @@ async def test_sql_database_incremental_initial_value(team, postgres_config, pos sync_type_config={"incremental_field": "id", "incremental_field_type": "integer"}, ) - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE: - res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team) - columns = res.columns + res = await sync_to_async(execute_hogql_query)("SELECT * FROM postgres_test_table", team) + columns = res.columns - assert columns is not None - assert len(columns) == 1 - assert any(x == "id" for x in columns) + assert columns is not None + assert len(columns) == 2 + assert any(x == "id" for x in columns) - # Include rows that have the same incremental value as the `initial_value` - assert len(res.results) == 1 + # Include rows that have the same incremental value as the `initial_value` + assert len(res.results) == 1 @pytest.mark.django_db(transaction=True) @@ -1038,26 +1010,6 @@ async def test_non_retryable_error_with_special_characters(team, stripe_customer await sync_to_async(execute_hogql_query)("SELECT * FROM stripe_customer", team) -@pytest.mark.django_db(transaction=True) -@pytest.mark.asyncio -async def test_delta_table_deleted(team, stripe_balance_transaction): - workflow_id, inputs = await _run( - team=team, - schema_name="BalanceTransaction", - table_name="stripe_balancetransaction", - source_type="Stripe", - job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id"}, - mock_data_response=stripe_balance_transaction["data"], - sync_type=ExternalDataSchema.SyncType.FULL_REFRESH, - ) - - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE: - with mock.patch.object(DeltaTable, "delete") as mock_delta_table_delete: - await _execute_run(str(uuid.uuid4()), inputs, stripe_balance_transaction["data"]) - - mock_delta_table_delete.assert_called_once() - - @pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_inconsistent_types_in_data(team, stripe_balance_transaction): @@ -1112,52 +1064,51 @@ async def test_postgres_uuid_type(team, postgres_config, postgres_connection): @pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_decimal_down_scales(team, postgres_config, postgres_connection): - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE_V2: - await postgres_connection.execute( - "CREATE TABLE IF NOT EXISTS {schema}.downsizing_column (id integer, dec_col numeric(10, 2))".format( - schema=postgres_config["schema"] - ) + await postgres_connection.execute( + "CREATE TABLE IF NOT EXISTS {schema}.downsizing_column (id integer, dec_col numeric(10, 2))".format( + schema=postgres_config["schema"] ) - await postgres_connection.execute( - "INSERT INTO {schema}.downsizing_column (id, dec_col) VALUES (1, 12345.60)".format( - schema=postgres_config["schema"] - ) + ) + await postgres_connection.execute( + "INSERT INTO {schema}.downsizing_column (id, dec_col) VALUES (1, 12345.60)".format( + schema=postgres_config["schema"] ) + ) - await postgres_connection.commit() + await postgres_connection.commit() - workflow_id, inputs = await _run( - team=team, - schema_name="downsizing_column", - table_name="postgres_downsizing_column", - source_type="Postgres", - job_inputs={ - "host": postgres_config["host"], - "port": postgres_config["port"], - "database": postgres_config["database"], - "user": postgres_config["user"], - "password": postgres_config["password"], - "schema": postgres_config["schema"], - "ssh_tunnel_enabled": "False", - }, - mock_data_response=[], - ) + workflow_id, inputs = await _run( + team=team, + schema_name="downsizing_column", + table_name="postgres_downsizing_column", + source_type="Postgres", + job_inputs={ + "host": postgres_config["host"], + "port": postgres_config["port"], + "database": postgres_config["database"], + "user": postgres_config["user"], + "password": postgres_config["password"], + "schema": postgres_config["schema"], + "ssh_tunnel_enabled": "False", + }, + mock_data_response=[], + ) - await postgres_connection.execute( - "ALTER TABLE {schema}.downsizing_column ALTER COLUMN dec_col type numeric(9, 2) using dec_col::numeric(9, 2);".format( - schema=postgres_config["schema"] - ) + await postgres_connection.execute( + "ALTER TABLE {schema}.downsizing_column ALTER COLUMN dec_col type numeric(9, 2) using dec_col::numeric(9, 2);".format( + schema=postgres_config["schema"] ) + ) - await postgres_connection.execute( - "INSERT INTO {schema}.downsizing_column (id, dec_col) VALUES (1, 1234567.89)".format( - schema=postgres_config["schema"] - ) + await postgres_connection.execute( + "INSERT INTO {schema}.downsizing_column (id, dec_col) VALUES (1, 1234567.89)".format( + schema=postgres_config["schema"] ) + ) - await postgres_connection.commit() + await postgres_connection.commit() - await _execute_run(str(uuid.uuid4()), inputs, []) + await _execute_run(str(uuid.uuid4()), inputs, []) @pytest.mark.django_db(transaction=True) @@ -1221,51 +1172,55 @@ async def test_postgres_nan_numerical_values(team, postgres_config, postgres_con mock_data_response=[], ) - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE: - res = await sync_to_async(execute_hogql_query)(f"SELECT * FROM postgres_numerical_nan", team) - columns = res.columns - results = res.results + res = await sync_to_async(execute_hogql_query)(f"SELECT * FROM postgres_numerical_nan", team) + columns = res.columns + results = res.results - assert columns is not None - assert len(columns) == 2 - assert columns[0] == "id" - assert columns[1] == "nan_column" + assert columns is not None + assert len(columns) == 3 + assert any(x == "_ph_debug" for x in columns) + assert any(x == "id" for x in columns) + assert any(x == "nan_column" for x in columns) - assert results is not None - assert len(results) == 1 - assert results[0] == (1, None) + assert results is not None + assert len(results) == 1 + + id_index = columns.index("id") + nan_index = columns.index("nan_column") + + assert results[0][id_index] == 1 + assert results[0][nan_index] is None @pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_delete_table_on_reset(team, stripe_balance_transaction): - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE_V2: - with ( - mock.patch.object(DeltaTable, "delete") as mock_delta_table_delete, - mock.patch.object(s3fs.S3FileSystem, "delete") as mock_s3_delete, - ): - workflow_id, inputs = await _run( - team=team, - schema_name="BalanceTransaction", - table_name="stripe_balancetransaction", - source_type="Stripe", - job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id", "reset_pipeline": "True"}, - mock_data_response=stripe_balance_transaction["data"], - ) + with ( + mock.patch.object(DeltaTable, "delete") as mock_delta_table_delete, + mock.patch.object(s3fs.S3FileSystem, "delete") as mock_s3_delete, + ): + workflow_id, inputs = await _run( + team=team, + schema_name="BalanceTransaction", + table_name="stripe_balancetransaction", + source_type="Stripe", + job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id", "reset_pipeline": "True"}, + mock_data_response=stripe_balance_transaction["data"], + ) - source = await sync_to_async(ExternalDataSource.objects.get)(id=inputs.external_data_source_id) + source = await sync_to_async(ExternalDataSource.objects.get)(id=inputs.external_data_source_id) - assert source.job_inputs is not None and isinstance(source.job_inputs, dict) - source.job_inputs["reset_pipeline"] = "True" + assert source.job_inputs is not None and isinstance(source.job_inputs, dict) + source.job_inputs["reset_pipeline"] = "True" - await sync_to_async(source.save)() + await sync_to_async(source.save)() - await _execute_run(str(uuid.uuid4()), inputs, stripe_balance_transaction["data"]) + await _execute_run(str(uuid.uuid4()), inputs, stripe_balance_transaction["data"]) - mock_delta_table_delete.assert_called() - mock_s3_delete.assert_called() + mock_delta_table_delete.assert_called() + mock_s3_delete.assert_called() - await sync_to_async(source.refresh_from_db)() + await sync_to_async(source.refresh_from_db)() - assert source.job_inputs is not None and isinstance(source.job_inputs, dict) - assert "reset_pipeline" not in source.job_inputs.keys() + assert source.job_inputs is not None and isinstance(source.job_inputs, dict) + assert "reset_pipeline" not in source.job_inputs.keys() diff --git a/posthog/temporal/tests/external_data/test_external_data_job.py b/posthog/temporal/tests/external_data/test_external_data_job.py index 103513662daeb..d54703d7419eb 100644 --- a/posthog/temporal/tests/external_data/test_external_data_job.py +++ b/posthog/temporal/tests/external_data/test_external_data_job.py @@ -1,4 +1,5 @@ from concurrent.futures import ThreadPoolExecutor +import os import uuid from unittest import mock from typing import Any, Optional @@ -16,7 +17,7 @@ ExternalDataJobWorkflow, ExternalDataWorkflowInputs, ) -from posthog.temporal.data_imports.pipelines.pipeline_sync import DataImportPipelineSync +from posthog.temporal.data_imports.pipelines.pipeline.pipeline import PipelineNonDLT from posthog.temporal.data_imports.workflow_activities.check_billing_limits import check_billing_limits_activity from posthog.temporal.data_imports.workflow_activities.create_job_model import ( CreateExternalDataJobModelActivityInputs, @@ -511,9 +512,24 @@ def mock_to_object_store_rs_credentials(class_self): AIRBYTE_BUCKET_REGION="us-east-1", BUCKET_NAME=BUCKET_NAME, ), + # Mock os.environ for the deltalake subprocess + mock.patch.dict( + os.environ, + { + "BUCKET_URL": f"s3://{BUCKET_NAME}", + "AIRBYTE_BUCKET_KEY": settings.OBJECT_STORAGE_ACCESS_KEY_ID, + "AIRBYTE_BUCKET_SECRET": settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + "AIRBYTE_BUCKET_REGION": "us-east-1", + "BUCKET_NAME": BUCKET_NAME, + "AIRBYTE_BUCKET_DOMAIN": "objectstorage:19000", + }, + ), mock.patch( "posthog.warehouse.models.table.DataWarehouseTable.get_columns", - return_value={"clickhouse": {"id": "string", "name": "string"}}, + return_value={ + "id": {"clickhouse": "string", "hogql": "StringDatabaseField"}, + "name": {"clickhouse": "string", "hogql": "StringDatabaseField"}, + }, ), mock.patch.object(AwsCredentials, "to_session_credentials", mock_to_session_credentials), mock.patch.object(AwsCredentials, "to_object_store_rs_credentials", mock_to_object_store_rs_credentials), @@ -523,7 +539,7 @@ def mock_to_object_store_rs_credentials(class_self): folder_path = job_1.folder_path() job_1_customer_objects = minio_client.list_objects_v2(Bucket=BUCKET_NAME, Prefix=f"{folder_path}/customer/") - assert len(job_1_customer_objects["Contents"]) == 2 + assert len(job_1_customer_objects["Contents"]) == 3 with ( mock.patch.object(RESTClient, "paginate", mock_charges_paginate), @@ -534,9 +550,24 @@ def mock_to_object_store_rs_credentials(class_self): AIRBYTE_BUCKET_REGION="us-east-1", BUCKET_NAME=BUCKET_NAME, ), + # Mock os.environ for the deltalake subprocess + mock.patch.dict( + os.environ, + { + "BUCKET_URL": f"s3://{BUCKET_NAME}", + "AIRBYTE_BUCKET_KEY": settings.OBJECT_STORAGE_ACCESS_KEY_ID, + "AIRBYTE_BUCKET_SECRET": settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + "AIRBYTE_BUCKET_REGION": "us-east-1", + "BUCKET_NAME": BUCKET_NAME, + "AIRBYTE_BUCKET_DOMAIN": "objectstorage:19000", + }, + ), mock.patch( "posthog.warehouse.models.table.DataWarehouseTable.get_columns", - return_value={"clickhouse": {"id": "string", "name": "string"}}, + return_value={ + "id": {"clickhouse": "string", "hogql": "StringDatabaseField"}, + "customer": {"clickhouse": "string", "hogql": "StringDatabaseField"}, + }, ), mock.patch.object(AwsCredentials, "to_session_credentials", mock_to_session_credentials), mock.patch.object(AwsCredentials, "to_object_store_rs_credentials", mock_to_object_store_rs_credentials), @@ -544,7 +575,7 @@ def mock_to_object_store_rs_credentials(class_self): activity_environment.run(import_data_activity_sync, job_2_inputs) job_2_charge_objects = minio_client.list_objects_v2(Bucket=BUCKET_NAME, Prefix=f"{job_2.folder_path()}/charge/") - assert len(job_2_charge_objects["Contents"]) == 2 + assert len(job_2_charge_objects["Contents"]) == 3 @pytest.mark.django_db(transaction=True) @@ -635,9 +666,24 @@ def mock_to_object_store_rs_credentials(class_self): AIRBYTE_BUCKET_REGION="us-east-1", BUCKET_NAME=BUCKET_NAME, ), + # Mock os.environ for the deltalake subprocess + mock.patch.dict( + os.environ, + { + "BUCKET_URL": f"s3://{BUCKET_NAME}", + "AIRBYTE_BUCKET_KEY": settings.OBJECT_STORAGE_ACCESS_KEY_ID, + "AIRBYTE_BUCKET_SECRET": settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + "AIRBYTE_BUCKET_REGION": "us-east-1", + "BUCKET_NAME": BUCKET_NAME, + "AIRBYTE_BUCKET_DOMAIN": "objectstorage:19000", + }, + ), mock.patch( "posthog.warehouse.models.table.DataWarehouseTable.get_columns", - return_value={"clickhouse": {"id": "string", "name": "string"}}, + return_value={ + "id": {"clickhouse": "string", "hogql": "StringDatabaseField"}, + "name": {"clickhouse": "string", "hogql": "StringDatabaseField"}, + }, ), mock.patch.object(AwsCredentials, "to_session_credentials", mock_to_session_credentials), mock.patch.object(AwsCredentials, "to_object_store_rs_credentials", mock_to_object_store_rs_credentials), @@ -647,7 +693,7 @@ def mock_to_object_store_rs_credentials(class_self): folder_path = job_1.folder_path() job_1_customer_objects = minio_client.list_objects_v2(Bucket=BUCKET_NAME, Prefix=f"{folder_path}/customer/") - assert len(job_1_customer_objects["Contents"]) == 2 + assert len(job_1_customer_objects["Contents"]) == 3 job_1.refresh_from_db() assert job_1.rows_synced == 1 @@ -687,15 +733,28 @@ def mock_func(inputs): with ( mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns", return_value={"id": "string"}), - mock.patch.object(DataImportPipelineSync, "run", mock_func), + mock.patch.object(PipelineNonDLT, "run", mock_func), ): - with override_settings( - BUCKET_URL=f"s3://{BUCKET_NAME}", - AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, - AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, - AIRBYTE_BUCKET_REGION="us-east-1", - AIRBYTE_BUCKET_DOMAIN="objectstorage:19000", - BUCKET_NAME=BUCKET_NAME, + with ( + override_settings( + BUCKET_URL=f"s3://{BUCKET_NAME}", + AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, + AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + AIRBYTE_BUCKET_REGION="us-east-1", + AIRBYTE_BUCKET_DOMAIN="objectstorage:19000", + BUCKET_NAME=BUCKET_NAME, + ), + mock.patch.dict( # Mock os.environ for the deltalake subprocess + os.environ, + { + "BUCKET_URL": f"s3://{BUCKET_NAME}", + "AIRBYTE_BUCKET_KEY": settings.OBJECT_STORAGE_ACCESS_KEY_ID, + "AIRBYTE_BUCKET_SECRET": settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + "AIRBYTE_BUCKET_REGION": "us-east-1", + "BUCKET_NAME": BUCKET_NAME, + "AIRBYTE_BUCKET_DOMAIN": "objectstorage:19000", + }, + ), ): async with await WorkflowEnvironment.start_time_skipping() as activity_environment: async with Worker( @@ -812,6 +871,18 @@ def mock_to_object_store_rs_credentials(class_self): AIRBYTE_BUCKET_DOMAIN="objectstorage:19000", BUCKET_NAME=BUCKET_NAME, ), + # Mock os.environ for the deltalake subprocess + mock.patch.dict( + os.environ, + { + "BUCKET_URL": f"s3://{BUCKET_NAME}", + "AIRBYTE_BUCKET_KEY": settings.OBJECT_STORAGE_ACCESS_KEY_ID, + "AIRBYTE_BUCKET_SECRET": settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + "AIRBYTE_BUCKET_REGION": "us-east-1", + "BUCKET_NAME": BUCKET_NAME, + "AIRBYTE_BUCKET_DOMAIN": "objectstorage:19000", + }, + ), mock.patch.object(AwsCredentials, "to_session_credentials", mock_to_session_credentials), mock.patch.object(AwsCredentials, "to_object_store_rs_credentials", mock_to_object_store_rs_credentials), ): @@ -819,4 +890,4 @@ def mock_to_object_store_rs_credentials(class_self): folder_path = await sync_to_async(job_1.folder_path)() job_1_team_objects = minio_client.list_objects_v2(Bucket=BUCKET_NAME, Prefix=f"{folder_path}/posthog_test/") - assert len(job_1_team_objects["Contents"]) == 2 + assert len(job_1_team_objects["Contents"]) == 3 diff --git a/posthog/warehouse/models/external_data_job.py b/posthog/warehouse/models/external_data_job.py index b7bd910b54097..5f1332a270b4f 100644 --- a/posthog/warehouse/models/external_data_job.py +++ b/posthog/warehouse/models/external_data_job.py @@ -42,17 +42,9 @@ def folder_path(self) -> str: def url_pattern_by_schema(self, schema: str) -> str: if TEST: - if self.pipeline_version == ExternalDataJob.PipelineVersion.V1: - return ( - f"http://{settings.AIRBYTE_BUCKET_DOMAIN}/{settings.BUCKET}/{self.folder_path()}/{schema.lower()}/" - ) - else: - return f"http://{settings.AIRBYTE_BUCKET_DOMAIN}/{settings.BUCKET}/{self.folder_path()}/{schema.lower()}__v2/" + return f"http://{settings.AIRBYTE_BUCKET_DOMAIN}/{settings.BUCKET}/{self.folder_path()}/{schema.lower()}/" - if self.pipeline_version == ExternalDataJob.PipelineVersion.V1: - return f"https://{settings.AIRBYTE_BUCKET_DOMAIN}/dlt/{self.folder_path()}/{schema.lower()}/" - - return f"https://{settings.AIRBYTE_BUCKET_DOMAIN}/dlt/{self.folder_path()}/{schema.lower()}__v2/" + return f"https://{settings.AIRBYTE_BUCKET_DOMAIN}/dlt/{self.folder_path()}/{schema.lower()}/" @database_sync_to_async diff --git a/posthog/warehouse/models/external_data_schema.py b/posthog/warehouse/models/external_data_schema.py index d90b699a7f4e0..843dad9b09420 100644 --- a/posthog/warehouse/models/external_data_schema.py +++ b/posthog/warehouse/models/external_data_schema.py @@ -8,7 +8,6 @@ import numpy import snowflake.connector from django.conf import settings -from posthog.constants import DATA_WAREHOUSE_TASK_QUEUE_V2 from posthog.models.team import Team from posthog.models.utils import CreatedMetaFields, DeletedMetaFields, UUIDModel, UpdatedMetaFields, sane_repr import uuid @@ -103,12 +102,7 @@ def update_incremental_field_last_value(self, last_value: Any) -> None: else: last_value_json = str(last_value_py) - if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE_V2: - key = "incremental_field_last_value_v2" - else: - key = "incremental_field_last_value" - - self.sync_type_config[key] = last_value_json + self.sync_type_config["incremental_field_last_value"] = last_value_json self.save() def soft_delete(self): diff --git a/posthog/warehouse/models/table.py b/posthog/warehouse/models/table.py index 0f960d2648c8d..cfb7926535e24 100644 --- a/posthog/warehouse/models/table.py +++ b/posthog/warehouse/models/table.py @@ -30,7 +30,7 @@ from posthog.hogql.context import HogQLContext if TYPE_CHECKING: - from posthog.warehouse.models import ExternalDataJob + pass SERIALIZED_FIELD_TO_CLICKHOUSE_MAPPING: dict[DatabaseSerializedFieldType, str] = { DatabaseSerializedFieldType.INTEGER: "Int64", @@ -143,7 +143,6 @@ def validate_column_type(self, column_key) -> bool: def get_columns( self, - pipeline_version: Optional["ExternalDataJob.PipelineVersion"] = None, safe_expose_ch_error: bool = True, ) -> DataWarehouseTableColumns: try: @@ -154,7 +153,6 @@ def get_columns( access_key=self.credential.access_key, access_secret=self.credential.access_secret, context=placeholder_context, - pipeline_version=pipeline_version, ) result = sync_execute(