diff --git a/database/interface.py b/database/interface.py index c261cb4f..c6b56559 100644 --- a/database/interface.py +++ b/database/interface.py @@ -24,6 +24,9 @@ PAGINATE_ENTRIES_PER_PAGE = 20 PAGINATE_START_PAGE = 0 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger() + def paginate(fn): @wraps(fn) @@ -55,16 +58,25 @@ def _impl(self, *args, **kwargs): count_q = query.statement.with_only_columns( func.count(), maintain_column_froms=True ).order_by(None) - count = query.session.execute(count_q).scalar() - return count + return query.session.execute(count_q).scalar() else: return query return _impl -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger() +def count_wrapper(fn): + """A wrapper that enables non-paginated functions to use the count decorater""" + + @wraps(fn) + def _impl(self, *args, **kwargs): + query = fn(self, *args, **kwargs) + if kwargs.get("count") is True: + return query + else: + return query.all() + + return _impl class HarvesterDBInterface: @@ -340,17 +352,17 @@ def get_harvest_job_errors_by_job(self, job_id: str) -> list[dict]: @paginate def get_harvest_record_errors_by_job(self, job_id: str, **kwargs): """ - Retrieves harvest record errors for a given job. + Retrieves harvest record errors for a given job. - This function fetches all records where the harvest status is 'error' and - belongs to the specified job. The query returns a tuple containing: - - HarvestRecordError object - - identifier (retrieved from HarvestRecord) - - source_raw (retrieved from HarvestRecord, containing 'title') + This function fetches all records where the harvest status is 'error' and + belongs to the specified job. The query returns a tuple containing: + - HarvestRecordError object + - identifier (retrieved from HarvestRecord) + - source_raw (retrieved from HarvestRecord, containing 'title') - Returns: - Query: A SQLAlchemy Query object that, when executed, yields tuples of: - (HarvestRecordError, identifier, source_raw). + Returns: + Query: A SQLAlchemy Query object that, when executed, yields tuples of: + (HarvestRecordError, identifier, source_raw). """ subquery = ( self.db.query(HarvestRecord.id) @@ -360,12 +372,11 @@ def get_harvest_record_errors_by_job(self, job_id: str, **kwargs): ) query = ( self.db.query( - HarvestRecordError, - HarvestRecord.identifier, - HarvestRecord.source_raw + HarvestRecordError, HarvestRecord.identifier, HarvestRecord.source_raw + ) + .join( + HarvestRecord, HarvestRecord.id == HarvestRecordError.harvest_record_id ) - .join(HarvestRecord, - HarvestRecord.id == HarvestRecordError.harvest_record_id) .filter(HarvestRecord.id.in_(select(subquery))) ) return query @@ -506,7 +517,9 @@ def get_all_outdated_records(self, days=365): return old_records_query.except_all(latest_successful_records_query).all() - def get_latest_harvest_records_by_source(self, source_id): + @count_wrapper + @count + def get_latest_harvest_records_by_source_orm(self, source_id, **kwargs): # datetimes are returned as datetime objs not strs subq = ( self.db.query(HarvestRecord) @@ -520,11 +533,10 @@ def get_latest_harvest_records_by_source(self, source_id): ) sq_alias = aliased(HarvestRecord, subq) - query = self.db.query(sq_alias).filter(sq_alias.action != "delete") + return self.db.query(sq_alias).filter(sq_alias.action != "delete") - records = query.all() - - return self._to_dict(records) + def get_latest_harvest_records_by_source(self, source_id): + return self._to_dict(self.get_latest_harvest_records_by_source_orm(source_id)) def close(self): if hasattr(self.db, "remove"): diff --git a/tests/badges/integration/tests.svg b/tests/badges/integration/tests.svg index dba5dd78..2bbe0c62 100644 --- a/tests/badges/integration/tests.svg +++ b/tests/badges/integration/tests.svg @@ -5,7 +5,7 @@ width="62.5" height="20" role="img" - aria-label="tests: 70" + aria-label="tests: 72" > -