diff --git a/lhotse/audio/backend.py b/lhotse/audio/backend.py index 64a9e1449..562387b56 100644 --- a/lhotse/audio/backend.py +++ b/lhotse/audio/backend.py @@ -478,11 +478,7 @@ def torchaudio_supports_ffmpeg() -> bool: # If user has disabled ffmpeg-torchaudio, we don't need to check the version. if not _FFMPEG_TORCHAUDIO_INFO_ENABLED: return False - - import torchaudio - from packaging import version - - return version.parse(torchaudio.__version__) >= version.parse("0.12.0") + return check_torchaudio_version_gt("0.12.0") @lru_cache(maxsize=1) @@ -494,14 +490,10 @@ def torchaudio_2_0_ffmpeg_enabled() -> bool: if not is_torchaudio_available(): return False - import torchaudio - from packaging import version - - ver = version.parse(torchaudio.__version__) - if ver >= version.parse("2.1.0"): + if check_torchaudio_version_gt("2.1.0"): # Enabled by default, disable with TORCHAUDIO_USE_BACKEND_DISPATCHER=0 return os.environ.get("TORCHAUDIO_USE_BACKEND_DISPATCHER", "1") == "1" - if ver >= version.parse("2.0"): + if check_torchaudio_version_gt("2.0"): # Disabled by default, enable with TORCHAUDIO_USE_BACKEND_DISPATCHER=1 return os.environ.get("TORCHAUDIO_USE_BACKEND_DISPATCHER", "0") == "1" return False @@ -513,10 +505,14 @@ def torchaudio_soundfile_supports_format() -> bool: Returns ``True`` when torchaudio version is at least 0.9.0, which has support for ``format`` keyword arg in ``torchaudio.save()``. """ + return check_torchaudio_version_gt("0.9.0") + + +def check_torchaudio_version_gt(version: str) -> bool: import torchaudio - from packaging import version + from packaging import version as _version - return version.parse(torchaudio.__version__) >= version.parse("0.9.0") + return _version.parse(torchaudio.__version__) >= _version.parse(version) def torchaudio_info( diff --git a/lhotse/shar/writers/audio.py b/lhotse/shar/writers/audio.py index 69db168a6..91b4d1291 100644 --- a/lhotse/shar/writers/audio.py +++ b/lhotse/shar/writers/audio.py @@ -9,6 +9,8 @@ from typing_extensions import Literal from lhotse import Recording +from lhotse.audio.backend import check_torchaudio_version_gt +from lhotse.augmentation import get_or_create_resampler from lhotse.shar.utils import to_shar_placeholder from lhotse.shar.writers.tar import TarWriter @@ -53,6 +55,10 @@ def __init__( self.save_fn = partial( torchaudio.backend.soundfile_backend.save, bits_per_sample=16 ) + if self.format == "opus": + assert check_torchaudio_version_gt( + "2.1.0" + ), "Writing OPUS files into Lhotse Shar requires torchaudio >= 2.1.0" def __enter__(self): self.tar_writer.__enter__() @@ -79,19 +85,28 @@ def write( sampling_rate: int, manifest: Recording, ) -> None: + # Resampling is required for some versions of OPUS encoders. + # First resample the manifest which only adjusts the metadata; + # then resample the audio array to 48kHz. + value = torch.from_numpy(value) + if self.format == "opus" and sampling_rate != OPUS_DEFAULT_SAMPLING_RATE: + manifest = manifest.resample(OPUS_DEFAULT_SAMPLING_RATE) + value = get_or_create_resampler(sampling_rate, OPUS_DEFAULT_SAMPLING_RATE)( + value + ) + sampling_rate = OPUS_DEFAULT_SAMPLING_RATE + # Write binary data stream = BytesIO() self.save_fn( stream, - torch.from_numpy(value), + value, sampling_rate, format=self.format, ) self.tar_writer.write(f"{key}.{self.format}", stream) # Write text manifest afterwards - if self.format == "opus" and sampling_rate != OPUS_DEFAULT_SAMPLING_RATE: - manifest = manifest.resample(OPUS_DEFAULT_SAMPLING_RATE) manifest = to_shar_placeholder(manifest) json_stream = BytesIO() print( diff --git a/test/shar/test_write.py b/test/shar/test_write.py index 1672b68b3..35a3f4487 100644 --- a/test/shar/test_write.py +++ b/test/shar/test_write.py @@ -6,6 +6,7 @@ import pytest from lhotse import CutSet +from lhotse.audio.backend import check_torchaudio_version_gt from lhotse.lazy import LazyJsonlIterator from lhotse.shar import AudioTarWriter, SharWriter, TarIterator, TarWriter from lhotse.testing.dummies import DummyManifest @@ -65,7 +66,27 @@ def test_tar_writer_pipe(tmp_path: Path): assert f2.read() == b"test" -@pytest.mark.parametrize("format", ["wav", "flac", "mp3", "opus"]) +@pytest.mark.parametrize( + "format", + [ + "wav", + pytest.param( + "flac", + marks=pytest.mark.skipif( + not check_torchaudio_version_gt("0.12.1"), + reason="Torchaudio v0.12.1 or greater is required.", + ), + ), + # "mp3", # apparently doesn't work in CI, mp3 encoder is missing + pytest.param( + "opus", + marks=pytest.mark.skipif( + not check_torchaudio_version_gt("2.1.0"), + reason="Torchaudio v2.1.0 or greater is required.", + ), + ), + ], +) def test_audio_tar_writer(tmp_path: Path, format: str): from lhotse.testing.dummies import dummy_recording