diff --git a/doc/troubleshooting.md b/doc/troubleshooting.md
index 681556b..75a4b3a 100644
--- a/doc/troubleshooting.md
+++ b/doc/troubleshooting.md
@@ -3,7 +3,9 @@
 This section provides a list of possible issues and solutions that may occur when using SB-OSC.
 
 ### apply_dml_events_validation_batch_size
-When setting `apply_dml_events_validation_batch_size` there are two factors to consider. Since the binlog resolution is in seconds, if the number of DML events in a second is greater than the batch size, the validation process can hang indefinitely. In this case, it is recommended to increase the batch size.
+~~When setting `apply_dml_events_validation_batch_size` there are two factors to consider. Since the binlog resolution is in seconds, if the number of DML events in a second is greater than the batch size, the validation process can hang indefinitely. In this case, it is recommended to increase the batch size.~~
+-> This issue was fixed by [#10](https://github.com/sendbird/sb-osc/pull/10)
+
 Another factor is `max_allowed_packet` of MySQL. The apply DML events stage issues a query with an IN clause containing `apply_dml_events_validation_batch_size` PKs. If the size of this query exceeds `max_allowed_packet`, the query will not return properly. In this case, it is recommended to decrease the batch size. You might also need to kill the running queries, since they may hang indefinitely in this case.
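The `max_allowed_packet` sizing concern in the hunk above can be estimated before picking a batch size. A rough, illustrative sketch (the helper names and the 20% headroom factor are assumptions for this example, not part of SB-OSC):

```python
def estimate_in_clause_bytes(batch_size: int, sample_pk: int = 10 ** 12) -> int:
    # Each integer PK renders as its decimal digits plus a separating comma.
    return batch_size * (len(str(sample_pk)) + 1)

def fits_max_allowed_packet(batch_size: int, max_allowed_packet: int) -> bool:
    # Leave ~20% headroom for the rest of the statement (SELECT ... WHERE id IN (...)).
    return estimate_in_clause_bytes(batch_size) * 1.2 < max_allowed_packet

# With MySQL 8.0's default max_allowed_packet of 64MB, a batch of
# 1,000,000 thirteen-digit PKs (~14MB) still fits comfortably.
print(fits_max_allowed_packet(1_000_000, 64 * 1024 * 1024))  # True
```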
diff --git a/src/sbosc/controller/controller.py b/src/sbosc/controller/controller.py
index 84ec8bb..2ad8453 100644
--- a/src/sbosc/controller/controller.py
+++ b/src/sbosc/controller/controller.py
@@ -178,12 +178,12 @@ def apply_dml_events_validation(self):
         is_valid = self.validator.apply_dml_events_validation()
 
         if is_valid:
-            # Analyze table
-            with self.db.cursor(host='dest') as cursor:
-                cursor: Cursor
-                metadata = self.redis_data.metadata
-                cursor.execute(f"ANALYZE TABLE {metadata.destination_db}.{metadata.destination_table}")
-                self.logger.info("Finished ANALYZE TABLE on destination table")
+            full_dml_event_validation_executed = self.validator.full_dml_event_validation()
+            if full_dml_event_validation_executed:  # validation did not skip
+                # full_dml_event_validation may take a long time, so apply_dml_events_validation
+                # needs to run again to validate the DML events that arrived in the meantime.
+                # Returning here makes the controller loop call apply_dml_events_validation again.
+                return
 
             if not self.is_preferred_window():
                 self.logger.info("Waiting for preferred window")
@@ -194,9 +194,12 @@
                 time.sleep(config.WAIT_INTERVAL_UNTIL_AUTO_SWAP_IN_SECONDS)
                 return
 
-            is_valid = self.validator.full_dml_event_validation()
-            if is_valid is not None:  # Validation did not skip
-                return
+            # Analyze table
+            with self.db.cursor(host='dest') as cursor:
+                cursor: Cursor
+                metadata = self.redis_data.metadata
+                cursor.execute(f"ANALYZE TABLE {metadata.destination_db}.{metadata.destination_table}")
+                self.logger.info("Finished ANALYZE TABLE on destination table")
 
             self.redis_data.set_current_stage(Stage.SWAP_TABLES)
             self.interval = 1
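The reordered control flow is easier to see in isolation. Below is a minimal, self-contained simulation (all names are hypothetical stand-ins, and the preferred-window check plus ANALYZE TABLE are collapsed into a flag): `full_dml_event_validation` now reports whether it actually ran, and the controller uses an early return to re-validate DML events that accumulated while the long full validation was running.

```python
class MiniController:
    def __init__(self, validator):
        self.validator = validator
        self.swapped = False

    def apply_dml_events_validation(self):
        if not self.validator.apply_dml_events_validation():
            return  # retry on the next tick
        if self.validator.full_dml_event_validation():
            # Full validation ran and may have taken a while; return so the
            # next tick re-validates DML events that arrived in the meantime.
            return
        self.swapped = True  # stand-in for ANALYZE TABLE + Stage.SWAP_TABLES

class FakeValidator:
    def __init__(self):
        self.full_runs = 0
    def apply_dml_events_validation(self):
        return True
    def full_dml_event_validation(self):
        self.full_runs += 1
        return self.full_runs == 1  # runs once, then reports "skipped"

controller = MiniController(FakeValidator())
for _ in range(3):  # the controller main loop (ticks)
    if controller.swapped:
        break
    controller.apply_dml_events_validation()
print(controller.swapped)  # True: swap happens only after a post-full-validation re-check
```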
Failed pks: {failed_pks}") @@ -99,7 +99,7 @@ def bulk_import_validation(self): self.logger.info("Bulk import validation succeeded") return is_valid - def get_timestamp_range(self): + def __get_timestamp_range(self): start_timestamp = None end_timestamp = None with self.db.cursor() as cursor: @@ -126,6 +126,8 @@ def get_timestamp_range(self): if cursor.rowcount > 0: start_timestamp = cursor.fetchone()[0] + # This ensures that all events up to the last event timestamp are all saved in the event tables + # save_current_binlog_position are called after save_events_to_db cursor.execute(f''' SELECT last_event_timestamp FROM {config.SBOSC_DB}.event_handler_status WHERE migration_id = {self.migration_id} ORDER BY id DESC LIMIT 1 @@ -134,45 +136,49 @@ def get_timestamp_range(self): end_timestamp = cursor.fetchone()[0] return start_timestamp, end_timestamp - def execute_apply_dml_events_validation_query( - self, source_cursor, dest_cursor, table, start_timestamp, end_timestamp, unmatched_pks): + def __execute_apply_dml_events_validation_query( + self, source_cursor, dest_cursor, table, event_pks: list, unmatched_pks: list): metadata = self.redis_data.metadata if table == 'inserted_pk': - not_inserted_pks = self.migration_operation.get_not_inserted_pks( - source_cursor, dest_cursor, start_timestamp, end_timestamp) + not_inserted_pks = self.migration_operation.get_not_inserted_pks(source_cursor, dest_cursor, event_pks) if not_inserted_pks: - self.logger.warning(f"Found {len(not_inserted_pks)} unmatched inserted pks") + self.logger.warning(f"Found {len(not_inserted_pks)} unmatched inserted pks: {not_inserted_pks}") unmatched_pks.extend([(pk, UnmatchType.NOT_UPDATED) for pk in not_inserted_pks]) elif table == 'updated_pk': - not_updated_pks = self.migration_operation.get_not_updated_pks( - source_cursor, dest_cursor, start_timestamp, end_timestamp) + not_updated_pks = self.migration_operation.get_not_updated_pks(source_cursor, dest_cursor, event_pks) if not_updated_pks: - self.logger.warning(f"Found {len(not_updated_pks)} unmatched updated pks") + self.logger.warning(f"Found {len(not_updated_pks)} unmatched updated pks: {not_updated_pks}") unmatched_pks.extend([(pk, UnmatchType.NOT_UPDATED) for pk in not_updated_pks]) elif table == 'deleted_pk': - source_cursor.execute(f''' - SELECT source_pk FROM {config.SBOSC_DB}.deleted_pk_{self.migration_id} - WHERE event_timestamp BETWEEN {start_timestamp} AND {end_timestamp} - ''') - if source_cursor.rowcount > 0: - target_pks = ','.join([str(row[0]) for row in source_cursor.fetchall()]) + if event_pks: + event_pks_str = ','.join([str(pk) for pk in event_pks]) dest_cursor.execute(f''' - SELECT id FROM {metadata.destination_db}.{metadata.destination_table} WHERE id IN ({target_pks}) + SELECT id FROM {metadata.destination_db}.{metadata.destination_table} WHERE id IN ({event_pks_str}) ''') - deleted_pks = set([row[0] for row in dest_cursor.fetchall()]) + not_deleted_pks = set([row[0] for row in dest_cursor.fetchall()]) if dest_cursor.rowcount > 0: # Check if deleted pks are reinserted source_cursor.execute(f''' - SELECT id FROM {metadata.source_db}.{metadata.source_table} WHERE id IN ({target_pks}) + SELECT id FROM {metadata.source_db}.{metadata.source_table} WHERE id IN ({event_pks_str}) ''') reinserted_pks = set([row[0] for row in source_cursor.fetchall()]) if reinserted_pks: - deleted_pks = deleted_pks - reinserted_pks - self.logger.warning(f"Found {len(reinserted_pks)} reinserted pks") - self.logger.warning(f"Found {len(deleted_pks)} unmatched deleted pks") - 
@@ -134,45 +136,49 @@
             if cursor.rowcount > 0:
                 end_timestamp = cursor.fetchone()[0]
         return start_timestamp, end_timestamp
 
-    def execute_apply_dml_events_validation_query(
-            self, source_cursor, dest_cursor, table, start_timestamp, end_timestamp, unmatched_pks):
+    def __execute_apply_dml_events_validation_query(
+            self, source_cursor, dest_cursor, table, event_pks: list, unmatched_pks: list):
         metadata = self.redis_data.metadata
         if table == 'inserted_pk':
-            not_inserted_pks = self.migration_operation.get_not_inserted_pks(
-                source_cursor, dest_cursor, start_timestamp, end_timestamp)
+            not_inserted_pks = self.migration_operation.get_not_inserted_pks(source_cursor, dest_cursor, event_pks)
             if not_inserted_pks:
-                self.logger.warning(f"Found {len(not_inserted_pks)} unmatched inserted pks")
+                self.logger.warning(f"Found {len(not_inserted_pks)} unmatched inserted pks: {not_inserted_pks}")
                 unmatched_pks.extend([(pk, UnmatchType.NOT_UPDATED) for pk in not_inserted_pks])
         elif table == 'updated_pk':
-            not_updated_pks = self.migration_operation.get_not_updated_pks(
-                source_cursor, dest_cursor, start_timestamp, end_timestamp)
+            not_updated_pks = self.migration_operation.get_not_updated_pks(source_cursor, dest_cursor, event_pks)
             if not_updated_pks:
-                self.logger.warning(f"Found {len(not_updated_pks)} unmatched updated pks")
+                self.logger.warning(f"Found {len(not_updated_pks)} unmatched updated pks: {not_updated_pks}")
                 unmatched_pks.extend([(pk, UnmatchType.NOT_UPDATED) for pk in not_updated_pks])
         elif table == 'deleted_pk':
-            source_cursor.execute(f'''
-                SELECT source_pk FROM {config.SBOSC_DB}.deleted_pk_{self.migration_id}
-                WHERE event_timestamp BETWEEN {start_timestamp} AND {end_timestamp}
-            ''')
-            if source_cursor.rowcount > 0:
-                target_pks = ','.join([str(row[0]) for row in source_cursor.fetchall()])
+            if event_pks:
+                event_pks_str = ','.join([str(pk) for pk in event_pks])
                 dest_cursor.execute(f'''
-                    SELECT id FROM {metadata.destination_db}.{metadata.destination_table} WHERE id IN ({target_pks})
+                    SELECT id FROM {metadata.destination_db}.{metadata.destination_table} WHERE id IN ({event_pks_str})
                 ''')
-                deleted_pks = set([row[0] for row in dest_cursor.fetchall()])
+                not_deleted_pks = set([row[0] for row in dest_cursor.fetchall()])
                 if dest_cursor.rowcount > 0:
                     # Check if deleted pks are reinserted
                     source_cursor.execute(f'''
-                        SELECT id FROM {metadata.source_db}.{metadata.source_table} WHERE id IN ({target_pks})
+                        SELECT id FROM {metadata.source_db}.{metadata.source_table} WHERE id IN ({event_pks_str})
                     ''')
                     reinserted_pks = set([row[0] for row in source_cursor.fetchall()])
                     if reinserted_pks:
-                        deleted_pks = deleted_pks - reinserted_pks
-                        self.logger.warning(f"Found {len(reinserted_pks)} reinserted pks")
-                self.logger.warning(f"Found {len(deleted_pks)} unmatched deleted pks")
-                unmatched_pks.extend([(pk, UnmatchType.NOT_REMOVED) for pk in deleted_pks])
-
-    def validate_apply_dml_events_batch(self, table, range_queue: Queue, unmatched_pks):
+                        not_deleted_pks = not_deleted_pks - reinserted_pks
+                        self.logger.warning(f"Found {len(reinserted_pks)} reinserted pks: {reinserted_pks}")
+                self.logger.warning(f"Found {len(not_deleted_pks)} unmatched deleted pks: {not_deleted_pks}")
+                unmatched_pks.extend([(pk, UnmatchType.NOT_REMOVED) for pk in not_deleted_pks])
+
+    def __get_event_pk_batch(self, cursor, table, start_timestamp, end_timestamp) -> Generator[List[int], None, None]:
+        cursor.execute(f'''
+            SELECT source_pk FROM {config.SBOSC_DB}.{table}_{self.migration_id}
+            WHERE event_timestamp BETWEEN {start_timestamp} AND {end_timestamp}
+        ''')
+        event_pks = [row[0] for row in cursor.fetchall()]
+        while event_pks:
+            yield event_pks[:self.apply_dml_events_batch_size]
+            event_pks = event_pks[self.apply_dml_events_batch_size:]
+
+    def __validate_apply_dml_events_batch(self, table, range_queue: Queue, unmatched_pks):
         with self.source_conn_pool.get_connection() as source_conn, self.dest_conn_pool.get_connection() as dest_conn:
             while not range_queue.empty():
                 if self.stop_flag:
@@ -180,6 +186,7 @@ def validate_apply_dml_events_batch(self, table, range_queue: Queue, unmatched_p
 
                 try:
                     batch_start_timestamp, batch_end_timestamp = range_queue.get_nowait()
+                    self.logger.info(f"Validating {table} from {batch_start_timestamp} to {batch_end_timestamp}")
                 except Empty:
                     self.logger.warning("Range queue is empty")
                     continue
@@ -195,7 +202,7 @@
                     WHERE event_timestamp BETWEEN {batch_start_timestamp} AND {batch_end_timestamp}
                 ''')
                 event_count = source_cursor.fetchone()[0]
-                if event_count > self.apply_dml_events_batch_size:
+                if event_count > self.apply_dml_events_batch_size and batch_end_timestamp > batch_start_timestamp:
                     range_queue.put((
                         batch_start_timestamp,
                         batch_start_timestamp + (batch_end_timestamp - batch_start_timestamp) // 2
@@ -208,17 +215,20 @@
 
                 else:
                     try:
-                        self.execute_apply_dml_events_validation_query(
-                            source_cursor, dest_cursor, table,
-                            batch_start_timestamp, batch_end_timestamp, unmatched_pks
+                        event_pk_batch = self.__get_event_pk_batch(
+                            source_cursor, table, batch_start_timestamp, batch_end_timestamp
                         )
+                        while event_pks := next(event_pk_batch, None):
+                            self.__execute_apply_dml_events_validation_query(
+                                source_cursor, dest_cursor, table, event_pks, unmatched_pks
+                            )
                     except MySQLdb.OperationalError as e:
-                        self.handle_operational_error(e, range_queue, batch_start_timestamp, batch_end_timestamp)
+                        self.__handle_operational_error(e, range_queue, batch_start_timestamp, batch_end_timestamp)
                         source_conn.ping(True)
                         dest_conn.ping(True)
                         continue
 
-    def validate_unmatched_pks(self):
+    def __validate_unmatched_pks(self):
         self.logger.info("Validating unmatched pks")
         with self.db.cursor() as cursor:
             cursor: Cursor
@@ -276,7 +286,7 @@ def validate_apply_dml_events(self, start_timestamp, end_timestamp):
             if table_rows > 0:
                 range_queue = Queue()
                 batch_start_timestamp = start_timestamp
-                while batch_start_timestamp < end_timestamp:
+                while batch_start_timestamp <= end_timestamp:
                     batch_duration = \
                         (end_timestamp - start_timestamp) * self.apply_dml_events_batch_size // table_rows
                     batch_end_timestamp = min(batch_start_timestamp + batch_duration, end_timestamp)
@@ -287,7 +297,7 @@ def validate_apply_dml_events(self, start_timestamp, end_timestamp):
                     threads = []
                     for _ in range(self.thread_count):
                         threads.append(executor.submit(
-                            self.validate_apply_dml_events_batch, table, range_queue, unmatched_pks))
+                            self.__validate_apply_dml_events_batch, table, range_queue, unmatched_pks))
                     for thread in threads:
                         thread.result()
 
@@ -295,7 +305,7 @@
                 INSERT IGNORE INTO {config.SBOSC_DB}.unmatched_rows (source_pk, migration_id, unmatch_type)
                 VALUES (%s, {self.migration_id}, %s)
             ''', unmatched_pks)
-            self.validate_unmatched_pks()
+            self.__validate_unmatched_pks()
             cursor.execute(
                 f"SELECT COUNT(1) FROM {config.SBOSC_DB}.unmatched_rows WHERE migration_id = {self.migration_id}")
             unmatched_rows = cursor.fetchone()[0]
@@ -307,7 +317,7 @@
 
     def apply_dml_events_validation(self):
         self.logger.info("Start apply DML events validation")
-        start_timestamp, end_timestamp = self.get_timestamp_range()
+        start_timestamp, end_timestamp = self.__get_timestamp_range()
         if start_timestamp is None:
             self.logger.warning("No events found. Skipping apply DML events validation")
             return True
@@ -329,8 +339,12 @@
 
     def full_dml_event_validation(self):
         """
-        :return: True if validation succeeded, False if validation failed, None if validation is skipped
+        :return: True if validation ran, False if it was skipped
        """
+        if self.full_dml_event_validation_interval == 0:
+            self.logger.info("Full DML event validation is disabled")
+            return False
+
         self.logger.info("Start full DML event validation")
 
         with self.db.cursor(role='reader') as cursor:
@@ -344,8 +358,10 @@
                 last_validation_time = cursor.fetchone()[0]
                 if datetime.now() - last_validation_time < timedelta(hours=self.full_dml_event_validation_interval):
                     self.logger.info(
-                        "Last validation was done less than 1 hour ago. Skipping full DML event validation")
-                    return
+                        f"Last validation was done less than {self.full_dml_event_validation_interval} hours ago. "
+                        f"Skipping full DML event validation"
+                    )
+                    return False
 
             cursor.execute(f'''
                 SELECT MIN(event_timestamps.min_ts) FROM (
@@ -358,7 +374,7 @@
             start_timestamp = cursor.fetchone()[0]
             if start_timestamp is None:
                 self.logger.warning("No events found. Skipping full DML event validation")
-                return
+                return False
 
             cursor.execute(f'''
                 SELECT last_event_timestamp FROM {config.SBOSC_DB}.event_handler_status
@@ -368,7 +384,7 @@
             end_timestamp = cursor.fetchone()[0]
             if end_timestamp is None:
                 self.logger.warning("Failed to get valid end_timestamp")
-                return
+                return False
 
             is_valid = self.validate_apply_dml_events(start_timestamp, end_timestamp)
 
@@ -379,4 +395,4 @@
                     VALUES ({self.migration_id}, {end_timestamp}, {is_valid}, NOW())
                 ''')
 
-        return is_valid
+        return True
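The new `__get_event_pk_batch` generator is the heart of the hang fix: a one-second timestamp range can no longer force an unbounded IN clause, because the PKs fetched for that range are validated in fixed-size slices. A standalone sketch of the same pattern (names are illustrative, not SB-OSC's):

```python
from typing import Generator, List

def chunk_pks(event_pks: List[int], batch_size: int) -> Generator[List[int], None, None]:
    # Yield fixed-size slices so each IN clause stays bounded, no matter how
    # many events share a single binlog timestamp.
    while event_pks:
        yield event_pks[:batch_size]
        event_pks = event_pks[batch_size:]

batches = chunk_pks(list(range(2500)), batch_size=1000)
while batch := next(batches, None):  # same walrus-loop shape as the diff
    print(len(batch))                # 1000, 1000, 500
```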
Skipping full DML event validation") - return + return False cursor.execute(f''' SELECT last_event_timestamp FROM {config.SBOSC_DB}.event_handler_status @@ -368,7 +384,7 @@ def full_dml_event_validation(self): end_timestamp = cursor.fetchone()[0] if end_timestamp is None: self.logger.warning("Failed to get valid end_timestamp") - return + return False is_valid = self.validate_apply_dml_events(start_timestamp, end_timestamp) @@ -379,4 +395,4 @@ def full_dml_event_validation(self): VALUES ({self.migration_id}, {end_timestamp}, {is_valid}, NOW()) ''') - return is_valid + return True diff --git a/src/sbosc/operations/base.py b/src/sbosc/operations/base.py index e296ea8..3ad0ebc 100644 --- a/src/sbosc/operations/base.py +++ b/src/sbosc/operations/base.py @@ -54,33 +54,33 @@ def get_not_imported_pks(self, source_cursor, dest_cursor, start_pk, end_pk): not_imported_pks = [row[0] for row in source_cursor.fetchall()] return not_imported_pks - def get_not_inserted_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp): + def get_not_inserted_pks(self, source_cursor, dest_cursor, event_pks): not_inserted_pks = [] - event_pks = self._get_event_pks(source_cursor, 'insert', start_timestamp, end_timestamp) if event_pks: + event_pks_str = ','.join([str(pk) for pk in event_pks]) source_cursor.execute(f''' SELECT source.id FROM {self.source_db}.{self.source_table} AS source LEFT JOIN {self.destination_db}.{self.destination_table} AS dest ON source.id = dest.id - WHERE source.id IN ({event_pks}) + WHERE source.id IN ({event_pks_str}) AND dest.id IS NULL ''') not_inserted_pks = [row[0] for row in source_cursor.fetchall()] return not_inserted_pks - def get_not_updated_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp): + def get_not_updated_pks(self, source_cursor, dest_cursor, event_pks): not_updated_pks = [] - event_pks = self._get_event_pks(source_cursor, 'update', start_timestamp, end_timestamp) if event_pks: + event_pks_str = ','.join([str(pk) for pk in event_pks]) source_cursor.execute(f''' SELECT combined.id FROM ( SELECT {self.source_columns}, 'source' AS table_type FROM {self.source_db}.{self.source_table} - WHERE id IN ({event_pks}) + WHERE id IN ({event_pks_str}) UNION ALL SELECT {self.source_columns}, 'destination' AS table_type FROM {self.destination_db}.{self.destination_table} - WHERE id IN ({event_pks}) + WHERE id IN ({event_pks_str}) ) AS combined GROUP BY {self.source_columns} HAVING COUNT(1) = 1 AND SUM(table_type = 'source') = 1 @@ -190,30 +190,30 @@ def get_not_imported_pks(self, source_cursor, dest_cursor, start_pk, end_pk): dest_pks = [row[0] for row in dest_cursor.fetchall()] return list(set(source_pks) - set(dest_pks)) - def get_not_inserted_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp): + def get_not_inserted_pks(self, source_cursor, dest_cursor, event_pks): not_inserted_pks = [] - event_pks = self._get_event_pks(source_cursor, 'insert', start_timestamp, end_timestamp) if event_pks: - source_cursor.execute(f"SELECT id FROM {self.source_db}.{self.source_table} WHERE id IN ({event_pks})") + event_pks_str = ','.join([str(pk) for pk in event_pks]) + source_cursor.execute(f"SELECT id FROM {self.source_db}.{self.source_table} WHERE id IN ({event_pks_str})") source_pks = [row[0] for row in source_cursor.fetchall()] dest_cursor.execute( - f"SELECT id FROM {self.destination_db}.{self.destination_table} WHERE id IN ({event_pks})") + f"SELECT id FROM {self.destination_db}.{self.destination_table} WHERE id IN ({event_pks_str})") dest_pks = 
            dest_pks = [row[0] for row in dest_cursor.fetchall()]
             not_inserted_pks = list(set(source_pks) - set(dest_pks))
         return not_inserted_pks
 
-    def get_not_updated_pks(self, source_cursor, dest_cursor, start_timestamp, end_timestamp):
+    def get_not_updated_pks(self, source_cursor, dest_cursor, event_pks):
         not_updated_pks = []
-        event_pks = self._get_event_pks(source_cursor, 'update', start_timestamp, end_timestamp)
         if event_pks:
+            event_pks_str = ','.join([str(pk) for pk in event_pks])
             source_cursor.execute(f'''
                 SELECT {self.source_columns} FROM {self.source_db}.{self.source_table}
-                WHERE id IN ({event_pks})
+                WHERE id IN ({event_pks_str})
             ''')
             source_df = pd.DataFrame(source_cursor.fetchall(), columns=[c[0] for c in source_cursor.description])
             dest_cursor.execute(f'''
                 SELECT {self.source_columns} FROM {self.destination_db}.{self.destination_table}
-                WHERE id IN ({event_pks})
+                WHERE id IN ({event_pks_str})
             ''')
             dest_df = pd.DataFrame(dest_cursor.fetchall(), columns=[c[0] for c in dest_cursor.description])
@@ -230,24 +230,27 @@ def get_not_updated_pks(self, source_cursor, dest_cursor, start_timestamp, end_t
 
     def get_rematched_updated_pks(self, db, not_updated_pks):
         not_updated_pks_str = ','.join([str(pk) for pk in not_updated_pks])
-        with db.cursor(host='source', role='reader') as cursor:
-            cursor: Cursor
-            cursor.execute(f'''
-                SELECT {self.source_columns} FROM {self.source_db}.{self.source_table}
-                WHERE id IN ({not_updated_pks_str})
-            ''')
-            source_df = pd.DataFrame(cursor.fetchall(), columns=[c[0] for c in cursor.description])
-        with db.cursor(host='dest', role='reader') as cursor:
-            cursor: Cursor
-            cursor.execute(f'''
-                SELECT {self.source_columns} FROM {self.destination_db}.{self.destination_table}
-                WHERE id IN ({not_updated_pks_str})
-            ''')
-            dest_df = pd.DataFrame(cursor.fetchall(), columns=[c[0] for c in cursor.description])
-
-        dest_df = dest_df.astype(source_df.dtypes.to_dict())
-        merged_df = source_df.merge(dest_df, how='inner', on=source_df.columns.tolist(), indicator=True)
-        rematched_pks = set(merged_df[merged_df['_merge'] == 'both']['id'].tolist())
+        # Get rematched_pks
+        try:
+            with db.cursor(host='source', role='reader') as cursor:
+                cursor: Cursor
+                cursor.execute(f'''
+                    SELECT {self.source_columns} FROM {self.source_db}.{self.source_table}
+                    WHERE id IN ({not_updated_pks_str})
+                ''')
+                source_df = pd.DataFrame(cursor.fetchall(), columns=[c[0] for c in cursor.description])
+            with db.cursor(host='dest', role='reader') as cursor:
+                cursor: Cursor
+                cursor.execute(f'''
+                    SELECT {self.source_columns} FROM {self.destination_db}.{self.destination_table}
+                    WHERE id IN ({not_updated_pks_str})
+                ''')
+                dest_df = pd.DataFrame(cursor.fetchall(), columns=[c[0] for c in cursor.description])
+            dest_df = dest_df.astype(source_df.dtypes.to_dict())
+            merged_df = source_df.merge(dest_df, how='inner', on=source_df.columns.tolist(), indicator=True)
+            rematched_pks = set(merged_df[merged_df['_merge'] == 'both']['id'].tolist())
+        except pd.errors.IntCastingNaNError:
+            rematched_pks = set()
 
         # add deleted pks
         with db.cursor(host='source', role='reader') as cursor:
             cursor.execute(f'''
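The UNION ALL query in `get_not_updated_pks` above is worth unpacking: a full row that occurs exactly once across both tables, with that single occurrence coming from the source side, has no identical counterpart in the destination. A toy reproduction (SQLite is used here only to keep the example self-contained; SB-OSC runs the equivalent against MySQL with `self.source_columns`):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE src (id INTEGER, a TEXT);
    CREATE TABLE dst (id INTEGER, a TEXT);
    INSERT INTO src VALUES (1, 'x'), (2, 'y'), (3, 'z');
    INSERT INTO dst VALUES (1, 'x'), (2, 'stale'), (3, 'z');
""")
rows = conn.execute("""
    SELECT combined.id FROM (
        SELECT id, a, 'source' AS table_type FROM src WHERE id IN (1, 2, 3)
        UNION ALL
        SELECT id, a, 'destination' AS table_type FROM dst WHERE id IN (1, 2, 3)
    ) AS combined
    GROUP BY id, a
    -- keep groups that appear once overall, and that once came from the source
    HAVING COUNT(1) = 1 AND SUM(table_type = 'source') = 1
""").fetchall()
print(rows)  # [(2,)] -- id 2 differs between source and destination
```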
diff --git a/src/sbosc/operations/operation.py b/src/sbosc/operations/operation.py
index 7c0d21b..0f47f5f 100644
--- a/src/sbosc/operations/operation.py
+++ b/src/sbosc/operations/operation.py
@@ -1,10 +1,9 @@
 from abc import abstractmethod
 from contextlib import contextmanager
-from typing import Literal
+from typing import List
 
 from MySQLdb.cursors import Cursor
 
-from config import config
 from modules.db import Database
 from modules.redis import RedisData
 
@@ -48,20 +47,8 @@ def get_not_imported_pks(self, source_cursor: Cursor, dest_cursor: Cursor, start
         """
         pass
 
-    def _get_event_pks(
-            self, cursor: Cursor, event_type: Literal['insert', 'update'], start_timestamp, end_timestamp):
-        table_names = {
-            'insert': f'inserted_pk_{self.migration_id}',
-            'update': f'updated_pk_{self.migration_id}'
-        }
-        cursor.execute(f'''
-            SELECT source_pk FROM {config.SBOSC_DB}.{table_names[event_type]}
-            WHERE event_timestamp BETWEEN {start_timestamp} AND {end_timestamp}
-        ''')
-        return ','.join([str(row[0]) for row in cursor.fetchall()])
-
     @abstractmethod
-    def get_not_inserted_pks(self, source_cursor: Cursor, dest_cursor: Cursor, start_timestamp, end_timestamp):
+    def get_not_inserted_pks(self, source_cursor: Cursor, dest_cursor: Cursor, event_pks: List[int]):
         """
         Returns a list of primary keys that have not been inserted into the destination table.
         Used in APPLY_DML_EVENTS_VALIDATION stage to validate that all inserts have been applied.
@@ -69,7 +56,7 @@ def get_not_inserted_pks(self, source_cursor: Cursor, dest_cursor: Cursor, start
         """
         pass
 
     @abstractmethod
-    def get_not_updated_pks(self, source_cursor: Cursor, dest_cursor: Cursor, start_timestamp, end_timestamp):
+    def get_not_updated_pks(self, source_cursor: Cursor, dest_cursor: Cursor, event_pks: List[int]):
         """
         Returns a list of primary keys that have not been updated in the destination table.
         Used in APPLY_DML_EVENTS_VALIDATION stage to validate that all updates have been applied.
diff --git a/tests/test_controller.py b/tests/test_controller.py
index df1601a..11d56a5 100644
--- a/tests/test_controller.py
+++ b/tests/test_controller.py
@@ -189,10 +189,10 @@ def test_apply_dml_events_validation(controller: Controller, setup_table, redis_
     cursor.execute(f"TRUNCATE TABLE {config.SBOSC_DB}.event_handler_status")
     cursor.execute(f"TRUNCATE TABLE {config.SBOSC_DB}.apply_dml_events_status")
 
-    # Event handler status doesn't have any row
+    # event_handler_status table doesn't have any row
     assert not controller.validator.apply_dml_events_validation()
 
-    # Insert row to event handler status and validate
+    # Insert row to event_handler_status table and validate
     cursor.execute(f'''
         INSERT INTO {config.SBOSC_DB}.event_handler_status (migration_id, log_file, log_pos, last_event_timestamp, created_at)
         VALUES (1, 'mysql-bin.000001', 4, {timestamp_range[1]}, NOW())
     ''')
@@ -241,7 +241,7 @@ def test_apply_dml_events_validation(controller: Controller, setup_table, redis_
     assert cursor.fetchone()[0] == 0
 
     # Add new insert, update event
-    new_timestamp_range = (101, 200)
+    new_timestamp_range = (100, 200)
     new_insert_events = [
         (random.randint(TABLE_SIZE, TABLE_SIZE * 2), random.randint(*new_timestamp_range)) for _ in range(500)]
     new_update_events = [(random.randint(1, TABLE_SIZE), random.randint(*new_timestamp_range)) for _ in range(500)]
@@ -278,8 +278,36 @@ def test_apply_dml_events_validation(controller: Controller, setup_table, redis_
     cursor.execute(f"SELECT COUNT(1) FROM {config.SBOSC_DB}.unmatched_rows")
     assert cursor.fetchone()[0] == 0
 
+    # More records inserted than apply_dml_events_validation batch size in 1 second
+    large_insert_events = {
+        (random.randint(TABLE_SIZE * 2, TABLE_SIZE * 3), 201) for _ in range(2000)}
+    cursor.executemany(f'''
+        INSERT IGNORE INTO {config.SBOSC_DB}.inserted_pk_1 (source_pk, event_timestamp) VALUES (%s, %s)
+    ''', large_insert_events)
+    cursor.executemany(f'''
+        INSERT IGNORE INTO {config.SOURCE_DB}.{config.SOURCE_TABLE} (id, A, B, C) VALUES (%s, %s, %s, %s)
+    ''', [(i[0], 'a', 'b', 'c') for i in large_insert_events])
+    cursor.execute(f'''
+        INSERT INTO {config.SBOSC_DB}.event_handler_status (migration_id, log_file, log_pos, last_event_timestamp, created_at)
+        VALUES (1, 'mysql-bin.000001', 4, 201, NOW())
+    ''')
+    controller.validator.apply_dml_events_validation()
+    cursor.execute(f"SELECT COUNT(1) FROM {config.SBOSC_DB}.unmatched_rows")
+    assert cursor.fetchone()[0] == len(large_insert_events)
+
+    # Apply changes to destination table
+    cursor.executemany(f'''
+        INSERT IGNORE INTO {config.DESTINATION_DB}.{config.DESTINATION_TABLE} (id, A, B, C) VALUES (%s, %s, %s, %s)
+    ''', [(i[0], 'a', 'b', 'c') for i in large_insert_events])
+
+    # requires 2 iterations to check all unmatched rows
+    controller.validator.apply_dml_events_validation()
+    assert controller.validator.apply_dml_events_validation()
+
     # Test full validation
     assert controller.validator.full_dml_event_validation()
+    cursor.execute(f"SELECT is_valid FROM {config.SBOSC_DB}.full_dml_event_validation_status")
+    assert cursor.fetchone()[0] == 1
 
     cursor.execute(f"USE {config.SBOSC_DB}")
     cursor.execute("TRUNCATE TABLE event_handler_status")