diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index 7b3b182..08c54c2 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -3527,12 +3527,40 @@ def _pwrite_syscall( ) bytes_just_written: int = os.pwrite(self._fd, data, offset) bytes_written += bytes_just_written - while bytes_written < expected_bytes_written and bytes_just_written > 0: + attempts: int = 0 + while bytes_written < expected_bytes_written and attempts < 3: # Writes larger than ~2 GiB may not complete in a single pwrite call offset += bytes_just_written with self._mv_suffix(data, bytes_written) as mv: + mv_size: int = mv.nbytes bytes_just_written = os.pwrite(self._fd, mv, offset) - bytes_written += bytes_just_written + if bytes_just_written > 0: + bytes_written += bytes_just_written + else: + # In case pwrite returns something strange + (logger.error if bytes_just_written < 0 else logger.info)( + ( + "pwrite: Supplementary write of %d bytes returned %d" + " with %d/%d bytes written (offset: %d)" + ), + mv_size, + bytes_just_written, + bytes_written, + expected_bytes_written, + offset, + ) + if bytes_just_written == 0: + if mv_size == 0: + logger.error("pwrite: Attempted to write 0 bytes") + break + attempts += 1 + logger.debug( + "pwrite: %s (attempt %d/3)", + "Retrying" if attempts < 3 else "Not retrying", + attempts, + ) + if attempts > 1: + time.sleep(0.2) if isinstance(verify, int) or verify: self._verify_bytes_written(bytes_written, expected_bytes_written) return bytes_written