From 72bc1d16e82cf031fcb84a88553e2e3eada5a5ba Mon Sep 17 00:00:00 2001 From: davidmezzetti <561939+davidmezzetti@users.noreply.github.com> Date: Sun, 22 Dec 2024 09:03:04 -0500 Subject: [PATCH] Close processes at end of Execute.run method, closes #57 --- src/python/paperetl/file/execute.py | 79 ++++++++++++++++++----------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/src/python/paperetl/file/execute.py b/src/python/paperetl/file/execute.py index 3cb3a37..08d4573 100644 --- a/src/python/paperetl/file/execute.py +++ b/src/python/paperetl/file/execute.py @@ -37,11 +37,7 @@ def mode(source, extension): file open mode """ - return ( - "rb" - if extension == "pdf" or (source and source.lower().startswith("pubmed")) - else "r" - ) + return "rb" if extension == "pdf" or (source and source.lower().startswith("pubmed")) else "r" @staticmethod def parse(path, source, extension, compress, config): @@ -60,9 +56,7 @@ def parse(path, source, extension, compress, config): # Determine if file needs to be open in binary or text mode mode = Execute.mode(source, extension) - with gzip.open(path, mode) if compress else open( - path, mode, encoding="utf-8" if mode == "r" else None - ) as stream: + with gzip.open(path, mode) if compress else open(path, mode, encoding="utf-8" if mode == "r" else None) as stream: if extension == "pdf": yield PDF.parse(stream, source) elif extension == "xml": @@ -119,9 +113,7 @@ def scan(indir, config, inputs): for f in sorted(files): # Extract file extension parts = f.lower().split(".") - extension, compress = ( - (parts[-2], True) if parts[-1] == "gz" else (parts[-1], False) - ) + extension, compress = (parts[-2], True) if parts[-1] == "gz" else (parts[-1], False) # Check if file ends with accepted extension if any(extension for ext in ["csv", "pdf", "xml"] if ext == extension): @@ -160,6 +152,26 @@ def save(processes, outputs, db): elif result: db.save(result) + @staticmethod + def close(processes, inputs, outputs): + """ + Closes open processes and queues. + + Args: + processes: list of processes + inputs: input queue + outputs: output queue + """ + + if processes: + # Close processes + for process in processes: + process.close() + + # Close queues + inputs.close() + outputs.close() + @staticmethod def run(indir, url, config=None, replace=False): """ @@ -172,29 +184,34 @@ def run(indir, url, config=None, replace=False): replace: if true, a new database will be created, overwriting any existing database """ - # Build database connection - db = Factory.create(url, replace) + processes, inputs, outputs = None, None, None + try: + # Build database connection + db = Factory.create(url, replace) - # Create queues, limit size of output queue - inputs, outputs = Queue(), Queue(30000) + # Create queues, limit size of output queue + inputs, outputs = Queue(), Queue(30000) - # Scan input directory and add files to inputs queue - total = Execute.scan(indir, config, inputs) + # Scan input directory and add files to inputs queue + total = Execute.scan(indir, config, inputs) - # Start worker processes - processes = [] - for _ in range(min(total, os.cpu_count())): - process = Process(target=Execute.process, args=(inputs, outputs)) - process.start() - processes.append(process) + # Start worker processes + processes = [] + for _ in range(min(total, os.cpu_count())): + process = Process(target=Execute.process, args=(inputs, outputs)) + process.start() + processes.append(process) - # Read results from worker processes and save to database - Execute.save(processes, outputs, db) + # Read results from worker processes and save to database + Execute.save(processes, outputs, db) - # Complete and close database - db.complete() - db.close() + # Complete and close database + db.complete() + db.close() - # Wait for processes to terminate - for process in processes: - process.join() + # Wait for processes to terminate + for process in processes: + process.join() + + finally: + Execute.close(processes, inputs, outputs)