Skip to content

Commit

Permalink
Add pass through for filesystem storage_options (#69)
Browse files Browse the repository at this point in the history
* add storage_options as a pass through to filesystem

* add storage_options as a pass through to filesystem async

* formatting fixes

* add storage_options as a pass through to dashboard directly

* add storage_options as a pass through to dashboard via cli
  • Loading branch information
joe-wolfe21 authored Apr 7, 2021
1 parent d7893dd commit 21afe5f
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 19 deletions.
16 changes: 13 additions & 3 deletions rubicon/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ def cli():
# Top level CLI commands


@cli.command()
@cli.command(
context_settings=dict(
ignore_unknown_options=True,
)
)
@click.option(
"--root-dir",
type=click.STRING,
Expand All @@ -35,9 +39,15 @@ def cli():
@click.option(
"--debug", "-d", type=click.BOOL, help="Whether or not to run in debug mode.", default=False
)
def ui(root_dir, host, port, debug, page_size):
@click.argument("storage_options", nargs=-1, type=click.UNPROCESSED)
def ui(root_dir, host, port, debug, page_size, storage_options):
"""Launch the Rubicon Dashboard."""
dashboard = Dashboard("filesystem", root_dir, page_size=page_size)
# Convert the pass-through storage options into a dict. Because the command
# ignores unknown options, they arrive unprocessed as either
# ('--key1', 'one', '--key2', 'two') or ('--key1=one', '--key2=two').
storage_options_dict = {}
pending_key = None  # option name seen last that is still waiting for its value
for token in storage_options:
    if pending_key is not None:
        # the token right after a bare '--key' is always taken as its value
        storage_options_dict[pending_key] = token
        pending_key = None
        continue
    if not token.startswith("--"):
        # a bare value with no preceding '--key' cannot be mapped to an option
        raise click.UsageError(
            f"Unexpected argument '{token}'; storage options must be given as '--key value'."
        )
    key, has_inline_value, value = token[2:].partition("=")
    if not key:
        raise click.UsageError(f"Invalid storage option '{token}'.")
    if has_inline_value:
        storage_options_dict[key] = value
    else:
        pending_key = key
if pending_key is not None:
    # fail with a readable usage error instead of an IndexError on a dangling key
    raise click.UsageError(f"Storage option '--{pending_key}' requires a value.")
dashboard = Dashboard("filesystem", root_dir, page_size=page_size, **storage_options_dict)

server_kwargs = dict(debug=debug, port=port, host=host)
server_kwargs = {k: v for k, v in server_kwargs.items() if v is not None}
Expand Down
9 changes: 7 additions & 2 deletions rubicon/client/asynchronous/rubicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@ class Rubicon(SyncRubicon):
True to use the `git` command to automatically log relevant repository
information to projects and experiments logged with this client instance,
False otherwise. Defaults to False.
storage_options : dict, optional
Additional keyword arguments specific to the protocol being chosen. They
are passed directly to the underlying filesystem class.
"""

def __init__(self, persistence="filesystem", root_dir=None, auto_git_enabled=False):
self.config = Config(persistence, root_dir, auto_git_enabled)
def __init__(
self, persistence="filesystem", root_dir=None, auto_git_enabled=False, **storage_options
):
self.config = Config(persistence, root_dir, auto_git_enabled, **storage_options)

async def create_project(self, name, description=None, github_url=None, training_metadata=None):
"""Overrides `rubicon.client.Rubicon.create_experiment`
Expand Down
10 changes: 8 additions & 2 deletions rubicon/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class Config:
True to use the `git` command to automatically log relevant repository
information to projects and experiments logged with this client instance,
False otherwise. Defaults to False.
storage_options : dict, optional
Additional keyword arguments specific to the protocol being chosen. They
are passed directly to the underlying filesystem class.
"""

PERSISTENCE_TYPES = ["filesystem", "memory"]
Expand All @@ -32,10 +35,13 @@ class Config:
"filesystem-s3": S3Repository,
}

def __init__(self, persistence=None, root_dir=None, is_auto_git_enabled=False):
def __init__(
self, persistence=None, root_dir=None, is_auto_git_enabled=False, **storage_options
):
self.persistence, self.root_dir, self.is_auto_git_enabled = self._load_config(
persistence, root_dir, is_auto_git_enabled
)
self.storage_options = storage_options

self.repository = self._get_repository()

Expand Down Expand Up @@ -85,4 +91,4 @@ def _get_repository(self):
+ f"`protocol` (from `root_dir`): {protocol}"
)

return repository(self.root_dir)
return repository(self.root_dir, **self.storage_options)
9 changes: 7 additions & 2 deletions rubicon/client/rubicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@ class Rubicon:
True to use the `git` command to automatically log relevant repository
information to projects and experiments logged with this client instance,
False otherwise. Defaults to False.
storage_options : dict, optional
Additional keyword arguments specific to the protocol being chosen. They
are passed directly to the underlying filesystem class.
"""

def __init__(self, persistence="filesystem", root_dir=None, auto_git_enabled=False):
self.config = Config(persistence, root_dir, auto_git_enabled)
def __init__(
self, persistence="filesystem", root_dir=None, auto_git_enabled=False, **storage_options
):
self.config = Config(persistence, root_dir, auto_git_enabled, **storage_options)

@property
def repository(self):
Expand Down
9 changes: 7 additions & 2 deletions rubicon/repository/asynchronous/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,18 @@ class AsynchronousBaseRepository(BaseRepository):
The event loop the asynchronous calling program is running on.
It should not be necessary to provide this parameter in
standard asynchronous operating cases.
storage_options : dict, optional
Additional keyword arguments that are passed directly to
the underlying filesystem class.
"""

PROTOCOL = None

def __init__(self, root_dir, loop=None):
def __init__(self, root_dir, loop=None, **storage_options):
self.root_dir = root_dir
self.filesystem = fsspec.filesystem(self.PROTOCOL, asynchronous=True, loop=loop)
self.filesystem = fsspec.filesystem(
self.PROTOCOL, asynchronous=True, loop=loop, **storage_options
)

self._is_connected = False

Expand Down
7 changes: 5 additions & 2 deletions rubicon/repository/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@ class BaseRepository:
root_dir : str
Absolute path to the root directory to persist Rubicon
data to.
storage_options : dict, optional
Additional keyword arguments that are passed directly to
the underlying filesystem class.
"""

def __init__(self, root_dir):
self.filesystem = fsspec.filesystem(self.PROTOCOL)
def __init__(self, root_dir, **storage_options):
self.filesystem = fsspec.filesystem(self.PROTOCOL, **storage_options)
self.root_dir = root_dir.rstrip("/")

def _ls_directories_only(self, path):
Expand Down
7 changes: 5 additions & 2 deletions rubicon/repository/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ class MemoryRepository(LocalRepository):
the in-memory filesystem. This does not need to be
specified unless interacting with an already created
in-memory filesystem.
storage_options : dict, optional
Additional keyword arguments that are passed directly to
the underlying filesystem class.
"""

PROTOCOL = "memory"

def __init__(self, root_dir=None):
self.filesystem = fsspec.filesystem(self.PROTOCOL)
def __init__(self, root_dir=None, **storage_options):
self.filesystem = fsspec.filesystem(self.PROTOCOL, **storage_options)
self.root_dir = root_dir.rstrip("/") if root_dir is not None else "/root"

self.filesystem.mkdir(self.root_dir)
Expand Down
7 changes: 5 additions & 2 deletions rubicon/ui/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@ class Dashboard:
page_size : int, optional
The number of rows that will be displayed on a page within the
experiment table.
storage_options : dict, optional
Additional keyword arguments specific to the protocol being chosen. They
are passed directly to the underlying filesystem class.
"""

def __init__(self, persistence, root_dir=None, page_size=10):
def __init__(self, persistence, root_dir=None, page_size=10, **storage_options):
self._app = app
self._app._page_size = page_size
self.rubicon_model = RubiconModel(persistence, root_dir)
self.rubicon_model = RubiconModel(persistence, root_dir, **storage_options)
self._app._rubicon_model = self.rubicon_model

self._app.layout = html.Div(
Expand Down
7 changes: 5 additions & 2 deletions rubicon/ui/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ class RubiconModel:
Absolute or relative filepath of the root directory holding Rubicon data.
Use absolute path for best performance. Defaults to the local filesystem.
Prefix with s3:// to use s3 instead.
storage_options : dict, optional
Additional keyword arguments specific to the protocol being chosen. They
are passed directly to the underlying filesystem class.
"""

def __init__(self, persistence, root_dir):
def __init__(self, persistence, root_dir, **storage_options):
self._rubicon_cls = AsynRubicon if root_dir.startswith("s3") else Rubicon
self._rubicon = self._rubicon_cls(persistence, root_dir)
self._rubicon = self._rubicon_cls(persistence, root_dir, **storage_options)

self._projects = []
self._selected_project = None
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/client/asynchronous/test_asyn_rubicon_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@
import dask.dataframe as dd

from rubicon import domain
from rubicon.client.asynchronous import Rubicon
from rubicon.exceptions import RubiconException


def test_repository_storage_options():
    """Storage options given to the async client reach the repository's filesystem."""
    rubicon_s3 = Rubicon(persistence="filesystem", root_dir="s3://nothing", key="secret")
    filesystem = rubicon_s3.config.repository.filesystem

    assert filesystem.storage_options["key"] == "secret"


def test_create_project(asyn_client_w_mock_repo):
rubicon = asyn_client_w_mock_repo

Expand Down
9 changes: 9 additions & 0 deletions tests/unit/client/test_rubicon_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ def test_set_repository(rubicon_client):
assert rubicon.config.repository == test_repo


def test_repository_storage_options():
    """Storage options given to the client reach the filesystem of each repository type."""
    # built in the same order as before: in-memory first, then S3-backed
    clients = [
        Rubicon(persistence="memory", root_dir="./", key="secret"),
        Rubicon(persistence="filesystem", root_dir="s3://nothing", key="secret"),
    ]

    for client in clients:
        assert client.config.repository.filesystem.storage_options["key"] == "secret"


def test_get_github_url(rubicon_client, mock_completed_process_git):
rubicon = rubicon_client

Expand Down

0 comments on commit 21afe5f

Please sign in to comment.