Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add multi-storage-client backend for file open #1455

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion lhotse/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import yaml
from packaging.version import parse as parse_version

from lhotse.utils import Pathlike, Pipe, SmartOpen, is_module_available, is_valid_url
from lhotse.utils import Pathlike, Pipe, SmartOpen, is_module_available, is_valid_url, replace_bucket_with_profile_name
from lhotse.workarounds import gzip_open_robust

# TODO: figure out how to use some sort of typing stubs
Expand Down Expand Up @@ -815,6 +815,82 @@ def handles_special_case(self, identifier: Pathlike) -> bool:
# Applicability check: accepts any identifier that ``is_valid_url`` recognizes as a URL.
def is_applicable(self, identifier: Pathlike) -> bool:
return is_valid_url(identifier)


@lru_cache(1)
def get_lhotse_msc_override_protocols() -> Any:
    """
    Return the raw value of the ``LHOTSE_MSC_OVERRIDE_PROTOCOLS`` environment
    variable, or ``None`` when it is not set.

    The result is memoized: the environment is consulted on the first call only.
    Use ``get_lhotse_msc_override_protocols.cache_clear()`` to force a re-read.
    """
    return os.environ.get("LHOTSE_MSC_OVERRIDE_PROTOCOLS")


@lru_cache(1)
def get_lhotse_msc_profile() -> Any:
    """
    Return the raw value of the ``LHOTSE_MSC_PROFILE`` environment variable,
    or ``None`` when it is not set.

    The result is memoized: the environment is consulted on the first call only.
    Use ``get_lhotse_msc_profile.cache_clear()`` to force a re-read.
    """
    return os.environ.get("LHOTSE_MSC_PROFILE")


@lru_cache(1)
def get_lhotse_io_backend() -> Any:
    """
    Return the raw value of the ``LHOTSE_IO_BACKEND`` environment variable,
    or ``None`` when it is not set.

    The result is memoized: the environment is consulted on the first call only.
    Use ``get_lhotse_io_backend.cache_clear()`` to force a re-read.
    """
    return os.environ.get("LHOTSE_IO_BACKEND")


# URL scheme (without the "://" separator) of identifiers that are passed to
# multi-storage client unchanged, e.g. "msc://profile/path/to/my/object".
MSC_PREFIX = "msc"

class MSCIOBackend(IOBackend):
    """
    Uses multi-storage client (MSC) to access data in an object store.
    """

    def open(self, identifier: str, mode: str):
        """
        Open ``identifier`` with ``multistorageclient.open``, converting it to
        an ``msc://`` URL first when needed.

        Identifiers that are already prefixed with ``msc://``,
        e.g. ``msc://profile/path/to/my/object1``, are opened unchanged.

        Identifiers that have not been migrated to an MSC-compatible URL yet,
        e.g. ``protocol://bucket/path/to/my/object2``, are rewritten as follows:

        1. If the URL scheme is listed in the comma-separated env var
           ``LHOTSE_MSC_OVERRIDE_PROTOCOLS``, the scheme is replaced with ``msc``:
           ``msc://bucket/path/to/my/object2``.
        2. If the env var ``LHOTSE_MSC_PROFILE`` is set, the bucket name is replaced
           with it: ``msc://profile/path/to/my/object2``. If it is not set,
           the MSC profile name is expected to match the bucket name.

        :param identifier: URL of the object to open.
        :param mode: file mode, passed through to ``multistorageclient.open``.
        :return: the object returned by ``multistorageclient.open``.
        :raises Exception: any error raised by ``multistorageclient.open`` is logged
            together with the (possibly rewritten) identifier and re-raised.
        """
        import logging

        import multistorageclient as msc

        # Accept Path-like input too; the string methods below need a ``str``.
        identifier = str(identifier)
        msc_url_prefix = f"{MSC_PREFIX}://"

        # Already an MSC URL: nothing to rewrite.
        if identifier.startswith(msc_url_prefix):
            return msc.open(identifier, mode)

        # Override the protocol if requested.
        override_protocols = get_lhotse_msc_override_protocols()
        if override_protocols:
            for protocol in override_protocols.split(","):
                protocol = protocol.strip()
                if not protocol:
                    # Ignore empty entries such as a trailing comma.
                    continue
                # Match the complete scheme ("s3://", not just "s3") and swap only
                # that leading part, so that "s3a://..." is not mistaken for "s3"
                # and later occurrences of the protocol name in the path are kept.
                scheme_prefix = f"{protocol}://"
                if identifier.startswith(scheme_prefix):
                    identifier = msc_url_prefix + identifier[len(scheme_prefix) :]
                    break

        # Override the bucket with the profile name if requested.
        profile_name = get_lhotse_msc_profile()
        if profile_name:
            identifier = replace_bucket_with_profile_name(identifier, profile_name)

        try:
            return msc.open(identifier, mode)
        except Exception as e:
            # Report the rewritten identifier, which may differ from the caller's input.
            logging.error("exception: %s, identifier: %s", e, identifier)
            raise

    @classmethod
    def is_available(cls) -> bool:
        """Return True when the ``multistorageclient`` module is available."""
        return is_module_available("multistorageclient")

    def handles_special_case(self, identifier: Pathlike) -> bool:
        """Return True when ``identifier`` already uses the ``msc://`` scheme."""
        return str(identifier).startswith(f"{MSC_PREFIX}://")

    def is_applicable(self, identifier: Pathlike) -> bool:
        """Return True when ``identifier`` is a valid URL according to ``is_valid_url``."""
        return is_valid_url(identifier)


class CompositeIOBackend(IOBackend):
"""
Expand Down Expand Up @@ -938,6 +1014,8 @@ def get_default_io_backend() -> "IOBackend":
RedirectIOBackend(),
PipeIOBackend(),
]
if MSCIOBackend.is_available():
backends.append(MSCIOBackend())
if AIStoreIOBackend.is_available():
# Try AIStore before other generalist backends,
# but only if it's installed and enabled via AIS_ENDPOINT env var.
Expand Down
8 changes: 7 additions & 1 deletion lhotse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
TypeVar,
Union,
)
from urllib.parse import urlparse
from urllib.parse import urlparse, urlunparse

import click
import numpy as np
Expand Down Expand Up @@ -1119,3 +1119,9 @@ def build_rng(seed: Union[int, Literal["trng"]]) -> random.Random:

# Returns the module-level ``_LHOTSE_DILL_ENABLED`` flag when it is truthy; otherwise
# falls back to the value of the ``LHOTSE_DILL_ENABLED`` environment variable.
# NOTE(review): the fallback uses ``os.environ[...]``, which raises KeyError when the
# variable is unset and yields a string rather than a bool when it is set -- confirm
# that the variable is always populated before this function is called.
def is_dill_enabled() -> bool:
return _LHOTSE_DILL_ENABLED or os.environ["LHOTSE_DILL_ENABLED"]


def replace_bucket_with_profile_name(identifier, profile_name):
    """
    Return ``identifier`` with its network-location component (the bucket name
    in ``protocol://bucket/path`` URLs) replaced by ``profile_name``.

    All other URL components (scheme, path, params, query, fragment) are kept.

    :param identifier: the URL to rewrite.
    :param profile_name: the value to put in place of the bucket name.
    :return: the rewritten URL.
    """
    scheme, _bucket, path, params, query, fragment = urlparse(identifier)
    return urlunparse((scheme, profile_name, path, params, query, fragment))
93 changes: 92 additions & 1 deletion test/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from tempfile import NamedTemporaryFile
import sys
import types

import pytest

Expand All @@ -18,7 +20,12 @@
store_manifest,
)
from lhotse.lazy import LazyJsonlIterator
from lhotse.serialization import SequentialJsonlWriter, load_manifest_lazy, open_best
from lhotse.serialization import (
MSCIOBackend,
SequentialJsonlWriter,
load_manifest_lazy,
open_best,
)
from lhotse.supervision import AlignmentItem
from lhotse.testing.dummies import DummyManifest
from lhotse.utils import fastcopy
Expand Down Expand Up @@ -516,3 +523,87 @@ def test_open_pipe_iter(tmp_path):
lines_read.append(l.strip())

assert lines_read == lines


@pytest.fixture
def clear_msc_env_caches():
    """
    Reset the memoized MSC environment-variable getters around a test.

    The getters are wrapped in ``lru_cache``, so they have to be cleared before
    the test to pick up freshly patched environment variables, and again after
    the test so that the values it read do not stay cached for later tests.
    """
    from lhotse.serialization import (
        get_lhotse_msc_override_protocols,
        get_lhotse_msc_profile,
    )

    def _reset_caches():
        get_lhotse_msc_profile.cache_clear()
        get_lhotse_msc_override_protocols.cache_clear()

    _reset_caches()
    yield
    _reset_caches()

@pytest.mark.parametrize(
    "identifier,expected_output,lhotse_msc_profile",
    [
        # No change for msc:// prefix
        ("msc://profile/path/to/object", "msc://profile/path/to/object", "profile"),
        # Override only protocol
        ("s3://bucket/path/to/object", "msc://bucket/path/to/object", ""),
        # Override protocol and bucket
        ("s3://bucket/path", "msc://profile/path", "profile"),
    ],
)
def test_msc_io_backend_url_conversion(
    monkeypatch, clear_msc_env_caches, identifier, expected_output, lhotse_msc_profile
):
    """``MSCIOBackend.open`` passes the expected, rewritten URL to ``multistorageclient.open``."""
    # Mock environment variables
    monkeypatch.setenv("LHOTSE_MSC_OVERRIDE_PROTOCOLS", "s3")
    if lhotse_msc_profile:
        monkeypatch.setenv("LHOTSE_MSC_PROFILE", lhotse_msc_profile)
    else:
        # Make sure a profile set in the developer's/CI environment cannot interfere.
        monkeypatch.delenv("LHOTSE_MSC_PROFILE", raising=False)

    # Mock multistorageclient.open to capture the transformed URL
    open_calls = []

    class MockMSC:
        def open(self, url, mode):
            open_calls.append((url, mode))
            return None

    mock_module = MockMSC()
    mock_module.__spec__ = None
    # Use monkeypatch so that the fake module is removed again after the test.
    monkeypatch.setitem(sys.modules, "multistorageclient", mock_module)

    # Create backend and test URL transformation
    backend = MSCIOBackend()
    backend.open(identifier, mode="r")

    # Asserting here (not inside the mock) also fails the test if open() was never called.
    assert open_calls == [(expected_output, "r")]


@pytest.mark.parametrize(
    "protocols",
    [
        "s3",  # Single protocol
        "s3,gs",  # Multiple protocols
    ],
)
def test_msc_io_backend_multiple_protocols(monkeypatch, clear_msc_env_caches, protocols):
    """Every protocol listed in ``LHOTSE_MSC_OVERRIDE_PROTOCOLS`` is rewritten to ``msc``."""
    # Mock environment variables
    monkeypatch.setenv("LHOTSE_MSC_OVERRIDE_PROTOCOLS", protocols)
    # Keep the bucket name predictable regardless of the surrounding environment.
    monkeypatch.delenv("LHOTSE_MSC_PROFILE", raising=False)

    # Mock multistorageclient.open to capture the transformed URL
    opened_urls = []

    class MockMSC:
        def open(self, url, mode):
            opened_urls.append(url)
            return None

    mock_module = MockMSC()
    mock_module.__spec__ = None
    # Use monkeypatch so that the fake module is removed again after the test.
    monkeypatch.setitem(sys.modules, "multistorageclient", mock_module)

    # Create backend and test URL transformation
    backend = MSCIOBackend()

    # Test with first protocol
    backend.open("s3://bucket/path", mode="r")
    assert opened_urls == ["msc://bucket/path"]

    if "," in protocols:
        # Test with second protocol if multiple
        backend.open("gs://bucket/path", mode="r")
        assert opened_urls == ["msc://bucket/path", "msc://bucket/path"]


def test_msc_io_backend_availability(monkeypatch):
    """``MSCIOBackend.is_available()`` follows the state of the ``multistorageclient`` module entry."""
    from lhotse.serialization import MSCIOBackend

    module_name = "multistorageclient"

    # Unavailable case: the sys.modules entry is patched to None.
    monkeypatch.setitem(sys.modules, module_name, None)
    assert not MSCIOBackend.is_available()

    # Available case: a stand-in object carrying a ``__spec__`` takes the module's place.
    class MockMSC:
        pass

    stand_in = MockMSC()
    stand_in.__spec__ = types.SimpleNamespace(name=module_name)
    monkeypatch.setitem(sys.modules, module_name, stand_in)
    assert MSCIOBackend.is_available()