client: torch: use old multiprocess API to support python 3.6
EL8.8 ships Python 3.6, whose ProcessPoolExecutor does not support the
`initializer` argument. That makes it unusable here: forked worker
processes must call `daos_reinit` before doing any work.

This commit replaces the ProcessPoolExecutor API with the underlying
multiprocessing API.

Features: DfuseFind

Signed-off-by: Denis Barakhtanov <[email protected]>
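
For context, here is a minimal sketch of the API difference (illustrative code only, not part of the diff; `worker_init` and `worker_fn` are placeholder names). On Python 3.7+ one would write `ProcessPoolExecutor(max_workers=n, initializer=worker_init)`; the Python 3.6-compatible equivalent drives plain `Process` workers over a `Queue` and calls the setup function at the top of each worker:

from multiprocessing import Process, Queue


def worker_init():
    # placeholder for per-process setup, e.g. re-initializing client
    # state after fork (the role daos_reinit plays in the real code)
    pass


def worker_fn(queue):
    worker_init()  # runs once in each worker process
    while True:
        item = queue.get()
        if item is None:  # sentinel: no more work
            break
        print("processing", item)


if __name__ == "__main__":
    queue = Queue()
    workers = [Process(target=worker_fn, args=(queue,)) for _ in range(2)]
    for w in workers:
        w.start()
    for item in ("a", "b", "c"):
        queue.put(item)
    for _ in workers:
        queue.put(None)  # one sentinel per worker
    for w in workers:
        w.join()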
0xE0F committed Jan 20, 2025
1 parent 4e189fc commit d1e2f51
Showing 2 changed files with 96 additions and 77 deletions.
171 changes: 95 additions & 76 deletions src/client/pydaos/torch/torch_api.py
@@ -9,12 +9,11 @@
 to access training data on DAOS DFS via POSIX container.
 """
 
-import concurrent
 import io
 import math
 import os
 import stat
-from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor
+from multiprocessing import Process, Queue
 
 from torch.utils.data import Dataset as TorchDataset
 from torch.utils.data import IterableDataset as TorchIterableDataset
@@ -268,7 +267,7 @@ class WriteBuffer(io.BufferedIOBase):
     using multiple workers. To use this mode set transfer_chunk_size to non-zero value.
     chunks_limit parameter is used to limit memory usage (only in chunked write mode):
-    no more than chunks_limit chunks will be in progress at the same time.
+    no more than chunks_limit chunks will be queued for writing to the storage.
 
     This class is not intended to be used directly: Checkpoint class is the main interface.
     """
@@ -290,13 +289,30 @@ def __init__(self, dfs, path, mode, open_flags, class_name,
         self._oflags = open_flags
         self._class_name = class_name
         self._file_chunk_size = file_chunk_size
         self._chunks_limit = chunks_limit
         self._transfer_chunk_size = transfer_chunk_size
-        self._inprogress = set()
 
+        self._workers = []
         if self._transfer_chunk_size > 0:
-            self._executor = ProcessPoolExecutor(
-                max_workers=workers, initializer=self._dfs.worker_init)
+            if chunks_limit == 0:
+                self._queue = Queue()
+            else:
+                self._queue = Queue(chunks_limit)
+
+            for _ in range(workers):
+                worker = Process(target=self._worker_fn, args=(self._queue,))
+                worker.start()
+                self._workers.append(worker)
+
+    def _worker_fn(self, queue):
+        self._dfs.worker_init()
+        while True:
+            work = queue.get()
+            if work is None:
+                break
+
+            (offset, chunk) = work
+            self._dfs.write(self._path, self._mode, self._oflags,
+                            self._class_name, self._file_chunk_size, offset, chunk)
 
     def write(self, data):
         """ Writes data to the buffer."""
@@ -312,10 +328,6 @@ def write(self, data):
 
         written = len(data)
         while len(data) > 0:
-            if self._reached_memory_usage_limit():
-                self._wait_for_completion()
-                continue
-
             fit = min(len(data), self._transfer_chunk_size - len(self._buffer))
             chunk = data[:fit]
             self._buffer.extend(chunk)
@@ -343,8 +355,11 @@ def close(self):
 
         self._flush()
         self._closed = True
-        if self._transfer_chunk_size > 0:
-            self._executor.shutdown(wait=True)
+
+        for _ in self._workers:
+            self._queue.put(None)
+        for worker in self._workers:
+            worker.join()
 
         super().close()

@@ -353,10 +368,7 @@ def _flush(self):
         if self.closed:
             raise ValueError("I/O operation on closed file")
 
-        if len(self._buffer) == 0 and len(self._inprogress) == 0:
-            return
-
-        if self._transfer_chunk_size == 0:
+        if len(self._buffer) > 0 and self._transfer_chunk_size == 0:
             self._dfs.write(self._path, self._mode, self._oflags,
                             self._class_name, self._file_chunk_size, 0, self._buffer)
             return
@@ -365,31 +377,14 @@
         self._submit_chunk(self._offset, self._buffer)
         self._offset += len(self._buffer)
 
-        while len(self._inprogress) > 0:
-            self._wait_for_completion()
-
     def _submit_chunk(self, offset, chunk):
-        """ Submits chunk for writing to the container """
-
-        self._inprogress.add(self._executor.submit(
-            self._dfs.write,
-            self._path, self._mode, self._oflags,
-            self._class_name, self._file_chunk_size, offset, chunk
-        ))
-
-    def _wait_for_completion(self):
-        """ Waits for at least one of the in-progress writes to complete """
-
-        (completed, inprogress) = concurrent.futures.wait(
-            self._inprogress, return_when=FIRST_COMPLETED)
-        for future in completed:
-            future.result()
-
-        self._inprogress = inprogress
-
-    def _reached_memory_usage_limit(self):
-        """ Returns True if we reached the memory usage limit """
-        return self._chunks_limit > 0 and len(self._inprogress) >= self._chunks_limit
+        """ Submits chunk for writing to the container.
+
+        It will block if the queue is full and has a size limit, forcing the caller to wait
+        until some of the chunks are written to the storage.
+        """
+
+        self._queue.put((offset, chunk))
 
     @property
     def closed(self):
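
The blocking behavior `_submit_chunk` now relies on comes directly from `multiprocessing.Queue`: `put()` on a size-limited queue blocks until a consumer frees a slot, which is what turns `chunks_limit` into a memory cap. A toy demonstration (illustrative only, not part of the diff):

import queue  # for the queue.Full exception
from multiprocessing import Queue

q = Queue(2)  # bounded queue, analogous to Queue(chunks_limit)
q.put("chunk-0")
q.put("chunk-1")
try:
    # queue is full: put() blocks, then raises queue.Full on timeout
    q.put("chunk-2", timeout=0.1)
except queue.Full:
    print("producer throttled until a worker drains the queue")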
@@ -533,28 +528,44 @@ def disconnect(self):
             raise OSError(ret, os.strerror(ret))
         self._dfs = None
 
-    def worker_fn(self, work, readdir_batch_size=READDIR_BATCH_SIZE):
+    def list_worker_fn(self, in_work, out_dirs, out_files, readdir_batch_size=READDIR_BATCH_SIZE):
         """
-        Reads the directory with indexed anchor.
-        Returns separate lists for files and directories, ready to be consumed by other workers.
+        Worker function to scan directory in parallel.
+        It expects to receive tuples (path, index) to scan the directory with an anchor index,
+        from the `in_work` queue.
+        It should emit tuples (scanned, to_scan) to the `out_dirs` queue, where `scanned` is the
+        number of scanned directories and `to_scan` is the list of directories to scan in parallel.
+        Upon completion it should emit the list of files in the `out_files` queue.
         """
 
-        (path, index) = work
+        self.worker_init()
 
-        dirs = []
-        files = []
-        ret = torch_shim.torch_list_with_anchor(DAOS_MAGIC, self._dfs,
-                                                path, index, files, dirs, readdir_batch_size
-                                                )
-        if ret != 0:
-            raise OSError(ret, os.strerror(ret), path)
+        result = []
+        while True:
+            work = in_work.get()
+            if work is None:
+                break
+
+            (path, index) = work
+
+            dirs = []
+            files = []
+            ret = torch_shim.torch_list_with_anchor(DAOS_MAGIC, self._dfs,
+                                                    path, index, files, dirs, readdir_batch_size
+                                                    )
+            if ret != 0:
+                raise OSError(ret, os.strerror(ret), path)
 
-        dirs = [chunk for d in dirs for chunk in self.split_dir_for_parallel_scan(
-            os.path.join(path, d))
-        ]
+            dirs = [chunk for d in dirs for chunk in self.split_dir_for_parallel_scan(
+                os.path.join(path, d))
+            ]
+            # Even if there are no dirs, we should emit the tuple to notify the main process
+            out_dirs.put((1, dirs))
 
-        files = [(os.path.join(path, fname), size) for (fname, size) in files]
-        return files, dirs
+            files = [(os.path.join(path, fname), size) for (fname, size) in files]
+            result.extend(files)
+
+        out_files.put(result)
 
     def split_dir_for_parallel_scan(self, path):
         """
@@ -584,28 +595,36 @@ def parallel_list(self, path=None,
         if not path.startswith(os.sep):
             raise ValueError("relative path is unacceptable")
 
+        procs = []
+        work = Queue()
+        dirs = Queue()
+        files = Queue()
+        for _ in range(workers):
+            worker = Process(target=self.list_worker_fn, args=(
+                work, dirs, files, readdir_batch_size))
+            worker.start()
+            procs.append(worker)
+
+        queued = 0
+        processed = 0
+        for dir in self.split_dir_for_parallel_scan(path):
+            work.put(dir)
+            queued += 1
+
+        while processed < queued:
+            (scanned, to_scan) = dirs.get()
+            processed += scanned
+            for dir in to_scan:
+                work.put(dir)
+                queued += 1
+
         result = []
-        inprogress = set()
-        dirs = self.split_dir_for_parallel_scan(path)
-        with ProcessPoolExecutor(max_workers=workers, initializer=self.worker_init) as pool:
-            while True:
-                batch = dirs[:workers]
-                dirs = dirs[len(batch):]
-
-                futures = [pool.submit(self.worker_fn, dir, readdir_batch_size) for dir in batch]
-
-                inprogress.update(futures)
-                (complete, incomplete) = concurrent.futures.wait(
-                    inprogress, return_when=FIRST_COMPLETED)
-
-                for fut in complete:
-                    files, to_process = fut.result()
-                    dirs.extend(to_process)
-                    result.extend(files)
-
-                inprogress = incomplete
-                if len(dirs) == 0 and len(inprogress) == 0:
-                    break
+        for _ in range(workers):
+            work.put(None)
+            result.extend(files.get())
+
+        for worker in procs:
+            worker.join()
 
         return result

GitHub Actions / Pylint check: warning on line 610 in src/client/pydaos/torch/torch_api.py: redefined-builtin, Redefining built-in 'dir'
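The termination protocol in `parallel_list` deserves a note: every queued directory results in exactly one `(scanned, to_scan)` message on the `dirs` queue, so the main process can stop dispatching once `processed` catches up with `queued`. A stripped-down sketch of the same counting scheme (illustrative code, not part of the diff; the fake `children` stand in for subdirectories):

from multiprocessing import Process, Queue


def scan_worker(in_work, out_dirs):
    while True:
        item = in_work.get()
        if item is None:  # sentinel: shut down
            break
        depth = item
        children = [depth + 1] * 2 if depth < 2 else []  # fake subdirectories
        out_dirs.put((1, children))  # exactly one message per work item


if __name__ == "__main__":
    in_work, out_dirs = Queue(), Queue()
    workers = [Process(target=scan_worker, args=(in_work, out_dirs))
               for _ in range(3)]
    for w in workers:
        w.start()

    in_work.put(0)  # seed with the root
    queued, processed = 1, 0
    while processed < queued:  # every queued item reports back exactly once
        scanned, to_scan = out_dirs.get()
        processed += scanned
        for child in to_scan:
            in_work.put(child)
            queued += 1

    for _ in workers:
        in_work.put(None)
    for w in workers:
        w.join()
    print("scanned", processed, "directories")  # 7 = 1 + 2 + 4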
2 changes: 1 addition & 1 deletion src/tests/ftest/pytorch/checkpoint.py
@@ -79,7 +79,7 @@ def _test_checkpoint(self, pool, cont, writes, chunk_size=0, chunks_limit=0, wor
         then reads written data back from it and compares it with the expected writes.
         """
 
-        self.log.info("Run checkpoint: writes=%s, chunk_size=%s, chunks_limit=%s, workers=%s",
+        self.log.info("Checkpoint test: writes=%s, chunk_size=%s, chunks_limit=%s, workers=%s",
                       len(writes), chunk_size, chunks_limit, workers)
         chkp = Checkpoint(pool, cont, transfer_chunk_size=chunk_size, chunks_limit=chunks_limit,
                           workers=workers)
