forked from URI-ABD/clam
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: nested cli sub-commands for plots
- Loading branch information
Showing
6 changed files
with
105 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
"""Plots and analysis of the results of the Cakes project.""" | ||
|
||
from . import scaling # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Plots and analysis of the search scaling results for Cakes.""" | ||
|
||
from . import reports # noqa: F401 | ||
from .plots import create_plots # noqa: F401 | ||
from .app import app # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
"""CLI command to create the plots for the scaling results of the Cakes search.""" | ||
|
||
import concurrent.futures | ||
import logging | ||
import pathlib | ||
|
||
import tqdm | ||
import typer | ||
|
||
from . import create_plots as _create_plots | ||
|
||
# Initialize the logger | ||
logger = logging.getLogger("scaling") | ||
logger.setLevel("INFO") | ||
|
||
app = typer.Typer() | ||
|
||
|
||
@app.command() | ||
def create_plots( | ||
input_dir: pathlib.Path = typer.Option( | ||
..., | ||
"--input-dir", | ||
"-i", | ||
help="The directory containing the reports from the scaling experiments.", | ||
exists=True, | ||
readable=True, | ||
file_okay=False, | ||
resolve_path=True, | ||
), | ||
output_dir: pathlib.Path = typer.Option( | ||
..., | ||
"--output-dir", | ||
"-o", | ||
help="The directory to save the plots.", | ||
exists=True, | ||
writable=True, | ||
file_okay=False, | ||
resolve_path=True, | ||
), | ||
) -> None: | ||
"""Create the plots for the scaling results of the Cakes search.""" | ||
logger.info(f"input_dir = {input_dir}") | ||
logger.info(f"output_dir = {output_dir}") | ||
|
||
files = list(input_dir.glob("*.json")) | ||
logger.info(f"Found {len(files)} json files.") | ||
|
||
with concurrent.futures.ProcessPoolExecutor() as executor: | ||
futures: list[concurrent.futures.Future[bool]] = [] | ||
for f in files: | ||
futures.append( | ||
executor.submit(_create_plots, f, False, output_dir), | ||
) | ||
|
||
for f in tqdm.tqdm( | ||
concurrent.futures.as_completed(futures), | ||
total=len(futures), | ||
desc="Processing files", | ||
): | ||
f.result() # type: ignore[attr-defined] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
"""Parser for the scaling results of the Cakes search.""" | ||
|
||
import json | ||
import pathlib | ||
import typing | ||
|
||
import pandas | ||
import pydantic | ||
|
||
|
||
class Report(pydantic.BaseModel): | ||
"""Report of the scaling results of the Cakes search.""" | ||
|
||
dataset: str | ||
metric: str | ||
base_cardinality: int | ||
dimensionality: int | ||
num_queries: int | ||
error_rate: float | ||
ks: list[int] | ||
csv_path: pathlib.Path = pathlib.Path(".").resolve() | ||
|
||
@staticmethod | ||
def from_json(json_path: pathlib.Path) -> "Report": | ||
"""Load the report from a JSON file.""" | ||
with json_path.open("r") as json_file: | ||
contents: dict[str, typing.Any] = json.load(json_file) | ||
contents["csv_path"] = json_path.parent.joinpath(contents.pop("csv_name")) | ||
return Report(**contents) | ||
|
||
def to_pandas(self) -> pandas.DataFrame: | ||
"""Read the CSV file into a pandas DataFrame.""" | ||
return pandas.read_csv(self.csv_path) |