diff --git a/database/interface.py b/database/interface.py index c261cb4f..c6b56559 100644 --- a/database/interface.py +++ b/database/interface.py @@ -24,6 +24,9 @@ PAGINATE_ENTRIES_PER_PAGE = 20 PAGINATE_START_PAGE = 0 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger() + def paginate(fn): @wraps(fn) @@ -55,16 +58,25 @@ def _impl(self, *args, **kwargs): count_q = query.statement.with_only_columns( func.count(), maintain_column_froms=True ).order_by(None) - count = query.session.execute(count_q).scalar() - return count + return query.session.execute(count_q).scalar() else: return query return _impl -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger() +def count_wrapper(fn): + """A wrapper that enables non-paginated functions to use the count decorator""" + + @wraps(fn) + def _impl(self, *args, **kwargs): + query = fn(self, *args, **kwargs) + if kwargs.get("count") is True: + return query + else: + return query.all() + + return _impl class HarvesterDBInterface: @@ -340,17 +352,17 @@ def get_harvest_job_errors_by_job(self, job_id: str) -> list[dict]: @paginate def get_harvest_record_errors_by_job(self, job_id: str, **kwargs): """ - Retrieves harvest record errors for a given job. + Retrieves harvest record errors for a given job. - This function fetches all records where the harvest status is 'error' and - belongs to the specified job. The query returns a tuple containing: - - HarvestRecordError object - - identifier (retrieved from HarvestRecord) - - source_raw (retrieved from HarvestRecord, containing 'title') + This function fetches all records where the harvest status is 'error' and + belongs to the specified job. The query returns a tuple containing: + - HarvestRecordError object + - identifier (retrieved from HarvestRecord) + - source_raw (retrieved from HarvestRecord, containing 'title') - Returns: - Query: A SQLAlchemy Query object that, when executed, yields tuples of: - (HarvestRecordError, identifier, source_raw). 
+ Returns: + Query: A SQLAlchemy Query object that, when executed, yields tuples of: + (HarvestRecordError, identifier, source_raw). """ subquery = ( self.db.query(HarvestRecord.id) @@ -360,12 +372,11 @@ def get_harvest_record_errors_by_job(self, job_id: str, **kwargs): ) query = ( self.db.query( - HarvestRecordError, - HarvestRecord.identifier, - HarvestRecord.source_raw + HarvestRecordError, HarvestRecord.identifier, HarvestRecord.source_raw + ) + .join( + HarvestRecord, HarvestRecord.id == HarvestRecordError.harvest_record_id ) - .join(HarvestRecord, - HarvestRecord.id == HarvestRecordError.harvest_record_id) .filter(HarvestRecord.id.in_(select(subquery))) ) return query @@ -506,7 +517,9 @@ def get_all_outdated_records(self, days=365): return old_records_query.except_all(latest_successful_records_query).all() - def get_latest_harvest_records_by_source(self, source_id): + @count_wrapper + @count + def get_latest_harvest_records_by_source_orm(self, source_id, **kwargs): # datetimes are returned as datetime objs not strs subq = ( self.db.query(HarvestRecord) @@ -520,11 +533,10 @@ def get_latest_harvest_records_by_source(self, source_id): ) sq_alias = aliased(HarvestRecord, subq) - query = self.db.query(sq_alias).filter(sq_alias.action != "delete") + return self.db.query(sq_alias).filter(sq_alias.action != "delete") - records = query.all() - - return self._to_dict(records) + def get_latest_harvest_records_by_source(self, source_id): + return self._to_dict(self.get_latest_harvest_records_by_source_orm(source_id)) def close(self): if hasattr(self.db, "remove"): diff --git a/tests/badges/integration/tests.svg b/tests/badges/integration/tests.svg index dba5dd78..2bbe0c62 100644 --- a/tests/badges/integration/tests.svg +++ b/tests/badges/integration/tests.svg @@ -5,7 +5,7 @@ width="62.5" height="20" role="img" - aria-label="tests: 70" + aria-label="tests: 72" > - tests: 70 + tests: 72 @@ -42,8 +42,8 @@ - 70 - 70 + 72 + 72 diff --git 
a/tests/integration/database/test_db.py b/tests/integration/database/test_db.py index ec3f57b2..ed1181b6 100644 --- a/tests/integration/database/test_db.py +++ b/tests/integration/database/test_db.py @@ -279,7 +279,27 @@ def test_endpoint_count( job_id, count=True, ) - assert count == len(record_data_dcatus) + assert count == len(record_data_dcatus) == 10 + + def test_endpoint_count_for_non_paginated_methods( + self, interface_with_fixture_json, source_data_dcatus, record_data_dcatus + ): + interface = interface_with_fixture_json + count = interface.get_latest_harvest_records_by_source_orm( + source_data_dcatus["id"], + count=True, + ) + assert ( + count + == len( + [ + record + for record in record_data_dcatus + if record["status"] == "success" + ] + ) + == 0 + ) def test_errors_by_job( self, @@ -555,7 +575,7 @@ def test_delete_outdated_records_and_errors( assert len(hs2_outdated.errors) == 1 for record in outdated_records: - interface.delete_harvest_record(ID=record.id) + interface.delete_harvest_record(record_id=record.id) # make sure only the outdated records and associated errors were deleted db_records = interface.pget_harvest_records(count=True) diff --git a/tests/integration/harvest_job_flows/test_harvest_job_full_flow.py b/tests/integration/harvest_job_flows/test_harvest_job_full_flow.py index 9041a692..b9658187 100644 --- a/tests/integration/harvest_job_flows/test_harvest_job_full_flow.py +++ b/tests/integration/harvest_job_flows/test_harvest_job_full_flow.py @@ -90,7 +90,12 @@ def test_multiple_harvest_jobs( source_data_dcatus["id"] ) - assert len(records_from_db) == 3 + records_from_db_count = interface.get_latest_harvest_records_by_source_orm( + source_data_dcatus["id"], + count=True, + ) + + assert len(records_from_db) == records_from_db_count == 7 @patch("harvester.harvest.ckan") @patch("harvester.harvest.download_file") @@ -124,7 +129,9 @@ def test_harvest_record_errors_reported( assert harvest_job.status == "complete" assert len(interface_errors) 
== harvest_job.records_errored assert len(interface_errors) == len(job_errors) - assert interface_errors[0][0].harvest_record_id == job_errors[0].harvest_record_id + assert ( + interface_errors[0][0].harvest_record_id == job_errors[0].harvest_record_id + ) @patch("harvester.harvest.ckan") @patch("harvester.utils.ckan_utils.uuid")