Skip to content

Commit

Permalink
Fix for Python 3.7
Browse files (browse the repository at this point in the history)
  • Loading branch information
pzelasko committed Dec 15, 2023
1 parent 512579b commit 2970470
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 17 deletions.
22 changes: 9 additions & 13 deletions lhotse/audio/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,7 @@ def torchaudio_supports_ffmpeg() -> bool:
# If user has disabled ffmpeg-torchaudio, we don't need to check the version.
if not _FFMPEG_TORCHAUDIO_INFO_ENABLED:
return False

import torchaudio
from packaging import version

return version.parse(torchaudio.__version__) >= version.parse("0.12.0")
return check_torchaudio_version_gt("0.12.0")


@lru_cache(maxsize=1)
Expand All @@ -494,14 +490,10 @@ def torchaudio_2_0_ffmpeg_enabled() -> bool:
if not is_torchaudio_available():
return False

import torchaudio
from packaging import version

ver = version.parse(torchaudio.__version__)
if ver >= version.parse("2.1.0"):
if check_torchaudio_version_gt("2.1.0"):
# Enabled by default, disable with TORCHAUDIO_USE_BACKEND_DISPATCHER=0
return os.environ.get("TORCHAUDIO_USE_BACKEND_DISPATCHER", "1") == "1"
if ver >= version.parse("2.0"):
if check_torchaudio_version_gt("2.0"):
# Disabled by default, enable with TORCHAUDIO_USE_BACKEND_DISPATCHER=1
return os.environ.get("TORCHAUDIO_USE_BACKEND_DISPATCHER", "0") == "1"
return False
Expand All @@ -513,10 +505,14 @@ def torchaudio_soundfile_supports_format() -> bool:
Returns ``True`` when torchaudio version is at least 0.9.0, which
has support for ``format`` keyword arg in ``torchaudio.save()``.
"""
return check_torchaudio_version_gt("0.9.0")


def check_torchaudio_version_gt(version: str) -> bool:
import torchaudio
from packaging import version
from packaging import version as _version

return version.parse(torchaudio.__version__) >= version.parse("0.9.0")
return _version.parse(torchaudio.__version__) >= _version.parse(version)


def torchaudio_info(
Expand Down
21 changes: 18 additions & 3 deletions lhotse/shar/writers/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing_extensions import Literal

from lhotse import Recording
from lhotse.audio.backend import check_torchaudio_version_gt
from lhotse.augmentation import get_or_create_resampler
from lhotse.shar.utils import to_shar_placeholder
from lhotse.shar.writers.tar import TarWriter

Expand Down Expand Up @@ -53,6 +55,10 @@ def __init__(
self.save_fn = partial(
torchaudio.backend.soundfile_backend.save, bits_per_sample=16
)
if self.format == "opus":
assert check_torchaudio_version_gt(

Check warning on line 59 in lhotse/shar/writers/audio.py

View check run for this annotation

Codecov / codecov/patch

lhotse/shar/writers/audio.py#L59

Added line #L59 was not covered by tests
"2.1.0"
), "Writing OPUS files into Lhotse Shar requires torchaudio >= 2.1.0"

def __enter__(self):
self.tar_writer.__enter__()
Expand All @@ -79,19 +85,28 @@ def write(
sampling_rate: int,
manifest: Recording,
) -> None:
# Resampling is required for some versions of OPUS encoders.
# First resample the manifest which only adjusts the metadata;
# then resample the audio array to 48kHz.
value = torch.from_numpy(value)
if self.format == "opus" and sampling_rate != OPUS_DEFAULT_SAMPLING_RATE:
manifest = manifest.resample(OPUS_DEFAULT_SAMPLING_RATE)
value = get_or_create_resampler(sampling_rate, OPUS_DEFAULT_SAMPLING_RATE)(

Check warning on line 94 in lhotse/shar/writers/audio.py

View check run for this annotation

Codecov / codecov/patch

lhotse/shar/writers/audio.py#L93-L94

Added lines #L93 - L94 were not covered by tests
value
)
sampling_rate = OPUS_DEFAULT_SAMPLING_RATE

Check warning on line 97 in lhotse/shar/writers/audio.py

View check run for this annotation

Codecov / codecov/patch

lhotse/shar/writers/audio.py#L97

Added line #L97 was not covered by tests

# Write binary data
stream = BytesIO()
self.save_fn(
stream,
torch.from_numpy(value),
value,
sampling_rate,
format=self.format,
)
self.tar_writer.write(f"{key}.{self.format}", stream)

# Write text manifest afterwards
if self.format == "opus" and sampling_rate != OPUS_DEFAULT_SAMPLING_RATE:
manifest = manifest.resample(OPUS_DEFAULT_SAMPLING_RATE)
manifest = to_shar_placeholder(manifest)
json_stream = BytesIO()
print(
Expand Down
23 changes: 22 additions & 1 deletion test/shar/test_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from lhotse import CutSet
from lhotse.audio.backend import check_torchaudio_version_gt
from lhotse.lazy import LazyJsonlIterator
from lhotse.shar import AudioTarWriter, SharWriter, TarIterator, TarWriter
from lhotse.testing.dummies import DummyManifest
Expand Down Expand Up @@ -65,7 +66,27 @@ def test_tar_writer_pipe(tmp_path: Path):
assert f2.read() == b"test"


@pytest.mark.parametrize("format", ["wav", "flac", "mp3", "opus"])
@pytest.mark.parametrize(
"format",
[
"wav",
pytest.param(
"flac",
marks=pytest.mark.skipif(
not check_torchaudio_version_gt("0.12.1"),
reason="Torchaudio v0.12.1 or greater is required.",
),
),
# "mp3", # apparently doesn't work in CI, mp3 encoder is missing
pytest.param(
"opus",
marks=pytest.mark.skipif(
not check_torchaudio_version_gt("2.1.0"),
reason="Torchaudio v2.1.0 or greater is required.",
),
),
],
)
def test_audio_tar_writer(tmp_path: Path, format: str):
from lhotse.testing.dummies import dummy_recording

Expand Down

0 comments on commit 2970470

Please sign in to comment.