diff --git a/egg/nest/wrappers.py b/egg/nest/wrappers.py
index 7f966b465..ace4d02e6 100644
--- a/egg/nest/wrappers.py
+++ b/egg/nest/wrappers.py
@@ -40,18 +40,28 @@ def __init__(self, runnable, log_dir, job_id):
 
     def __call__(self, args):
         stdout_path = pathlib.Path(self.log_dir) / f"{self.job_id}.out"
-        self.stdout = open(stdout_path, "w")
-
         stderr_path = pathlib.Path(self.log_dir) / f"{self.job_id}.err"
-        self.stderr = open(stderr_path, "w")
-
-        sys.stdout = self.stdout
-        sys.stderr = self.stderr
-        cuda_id = -1
-        n_devices = torch.cuda.device_count()
-        if n_devices > 0:
-            cuda_id = self.job_id % n_devices
-        print(f"# {json.dumps(args)}", flush=True)
-
-        with torch.cuda.device(cuda_id):
-            self.runnable(args)
+
+        with open(stdout_path, "w") as self.stdout, open(
+            stderr_path, "w"
+        ) as self.stderr:
+            original_stdout = sys.stdout
+            original_stderr = sys.stderr
+            sys.stdout = self.stdout
+            sys.stderr = self.stderr
+
+            cuda_id = -1
+            n_devices = torch.cuda.device_count()
+            if n_devices > 0:
+                cuda_id = self.job_id % n_devices
+
+            print(f"# {json.dumps(args)}", flush=True)
+
+            with torch.cuda.device(cuda_id):
+                self.runnable(args)
+
+            sys.stdout.flush()
+            sys.stderr.flush()
+
+            sys.stdout = original_stdout
+            sys.stderr = original_stderr
diff --git a/tests/test_concurrent_wrapper.py b/tests/test_concurrent_wrapper.py
new file mode 100644
index 000000000..d19740f2d
--- /dev/null
+++ b/tests/test_concurrent_wrapper.py
@@ -0,0 +1,76 @@
+import json
+import multiprocessing
+import pathlib
+import sys
+import time
+
+import pytest
+
+from egg.nest.wrappers import ConcurrentWrapper
+
+multiprocessing.set_start_method(
+    "spawn", force=True
+)  # avoiding issue with CUDA re-initialization in a forked subprocess
+
+
+def dummy_runnable(args):
+    print("Running dummy_runnable")
+    print(json.dumps(args), file=sys.stderr)
+
+
+def test_file_descriptor_closure(tmp_path):
+    """
+    Test to check if file descriptors are closed.
+    Attempting to write to a closed file should raise a ValueError
+    """
+    runnable = dummy_runnable
+    log_dir = tmp_path
+    job_id = 1
+
+    wrapper = ConcurrentWrapper(runnable, log_dir, job_id)
+    wrapper({"key": "value"})
+
+    with pytest.raises(ValueError):
+        wrapper.stdout.write("This should fail if the file is closed.")
+
+    with pytest.raises(ValueError):
+        wrapper.stderr.write("This should fail if the file is closed.")
+
+
+def test_stdout_stderr_restoration(tmp_path):
+    """Test to ensure sys.stdout and sys.stderr are restored"""
+    original_stdout = sys.stdout
+    original_stderr = sys.stderr
+
+    runnable = dummy_runnable
+    log_dir = tmp_path
+    job_id = 2
+
+    wrapper = ConcurrentWrapper(runnable, log_dir, job_id)
+    wrapper({"another_key": "another_value"})
+
+    assert sys.stdout == original_stdout
+    assert sys.stderr == original_stderr
+
+
+def delayed_print_runnable(args):
+    print("This is a test.")
+    time.sleep(0.1)  # Introduce a slight delay
+
+
+def test_delayed_output_capture(tmp_path):
+    log_dir = tmp_path
+    job_id = 1
+
+    runner = ConcurrentWrapper(
+        runnable=delayed_print_runnable, log_dir=log_dir, job_id=job_id
+    )
+
+    runner([])
+
+    stdout_path = pathlib.Path(log_dir) / f"{job_id}.out"
+
+    with open(stdout_path, "r") as f:
+        output = f.read()
+
+    assert "This is a test." in output, "Expected output was not captured in the file."