-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from DiamondLightSource/add_csv_reader
Add csv reader
- Loading branch information
Showing
8 changed files
with
376 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,78 @@ | ||
"""Interface for ``python -m bimorph_mirror_analysis``.""" | ||
|
||
from argparse import ArgumentParser | ||
from collections.abc import Sequence | ||
import datetime | ||
|
||
import numpy as np | ||
import typer | ||
|
||
from bimorph_mirror_analysis.maths import find_voltages | ||
from bimorph_mirror_analysis.read_file import read_bluesky_plan_output | ||
|
||
from . import __version__ | ||
|
||
__all__ = ["main"] | ||
|
||
app = typer.Typer() | ||
|
||
def main(args: Sequence[str] | None = None) -> None: | ||
"""Argument parser for the CLI.""" | ||
parser = ArgumentParser() | ||
parser.add_argument( | ||
"-v", | ||
"--version", | ||
action="version", | ||
version=__version__, | ||
|
||
@app.command(name=None) | ||
def calculate_voltages( | ||
file_path: str = typer.Argument(help="The path to the csv file to be read."), | ||
output_path: str | None = typer.Option( | ||
None, | ||
help="The path to save the output\ | ||
optimal voltages to, optional.", | ||
), | ||
): | ||
file_type = file_path.split(".")[-1] | ||
optimal_voltages = calculate_optimal_voltages(file_path) | ||
optimal_voltages = np.round(optimal_voltages, 2) | ||
date = datetime.datetime.now().date() | ||
|
||
if output_path is None: | ||
output_path = f"{file_path.replace(f'.{file_type}', '')}\ | ||
_optimal_voltages_{date}.csv" | ||
|
||
np.savetxt( | ||
output_path, | ||
optimal_voltages, | ||
fmt="%.2f", | ||
) | ||
parser.parse_args(args) | ||
print(f"The optimal voltages have been saved to {output_path}") | ||
print( | ||
f"The optimal voltages are: [{', '.join([str(i) for i in optimal_voltages])}]" | ||
) | ||
|
||
|
||
def version_callback(value: bool): | ||
if value: | ||
typer.echo(f"Version: {__version__}") | ||
raise typer.Exit() | ||
|
||
|
||
@app.callback() | ||
def main( | ||
version: bool = typer.Option( | ||
None, | ||
"--version", | ||
"-v", | ||
callback=version_callback, | ||
is_eager=True, | ||
help="Show the application's version and exit", | ||
), | ||
): | ||
pass | ||
|
||
|
||
def calculate_optimal_voltages(file_path: str) -> np.typing.NDArray[np.float64]: | ||
pivoted, initial_voltages, increment = read_bluesky_plan_output(file_path) | ||
# numpy array of pencil beam scans | ||
data = pivoted[pivoted.columns[1:]].to_numpy() # type: ignore | ||
|
||
voltage_adjustments = find_voltages(data, increment) # type: ignore | ||
optimal_voltages = initial_voltages + voltage_adjustments | ||
return optimal_voltages # type: ignore | ||
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
app() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
def read_bluesky_plan_output( | ||
filepath: str, | ||
) -> tuple[pd.DataFrame, np.typing.NDArray[np.float64], float]: | ||
"""Read the csv file putput by the bluesky plan | ||
Reads the file and returns the dataframe with individual pecil beam scans as | ||
columns, the initial voltages and the voltage increment. | ||
Args: | ||
filepath (str): The path to the csv file to be read. | ||
Returns: | ||
A tuple containing the DataFrame, the initial voltages array and the voltage | ||
incrememnt. | ||
""" | ||
data = pd.read_csv(filepath) # type: ignore | ||
data = data.apply(pd.to_numeric, errors="coerce") # type: ignore | ||
|
||
voltage_cols = [col for col in data.columns if "voltage" in col] | ||
initial_voltages = data.loc[0, voltage_cols].to_numpy() # type: ignore | ||
final_voltages = data.loc[len(data) - 1, voltage_cols].to_numpy() # type: ignore | ||
|
||
voltage_increment = final_voltages[0] - initial_voltages[0] # type: ignore | ||
|
||
pivoted = pd.pivot_table( # type: ignore | ||
data, | ||
values="centroid_position_x", | ||
index=["slit_position_x"], | ||
columns=["pencil_beam_scan_number"], | ||
) | ||
pivoted.columns = ["pencil_beam_scan_" + str(col) for col in pivoted.columns] | ||
pivoted.reset_index(inplace=True) | ||
return pivoted, initial_voltages, voltage_increment # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from unittest.mock import patch | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from bimorph_mirror_analysis.__main__ import calculate_optimal_voltages | ||
from bimorph_mirror_analysis.maths import find_voltages | ||
|
||
|
||
def test_calculate_optimal_voltages_mocked(raw_data_pivoted: pd.DataFrame): | ||
with ( | ||
patch( | ||
"bimorph_mirror_analysis.__main__.read_bluesky_plan_output" | ||
) as mock_read_bluesky_plan_output, | ||
patch("bimorph_mirror_analysis.__main__.find_voltages") as mock_find_voltages, | ||
): | ||
# set the mock return values | ||
mock_read_bluesky_plan_output.return_value = ( | ||
raw_data_pivoted, | ||
np.array([0.0, 0.0, 0.0]), | ||
100, | ||
) | ||
mock_find_voltages.side_effect = find_voltages | ||
voltages = calculate_optimal_voltages("input_file") | ||
voltages = np.round(voltages, 2) | ||
# assert correct voltages calculated | ||
np.testing.assert_almost_equal(voltages, np.array([72.14, 50.98, 18.59])) | ||
|
||
# assert mock was called | ||
mock_read_bluesky_plan_output.assert_called() | ||
mock_read_bluesky_plan_output.assert_called_with("input_file") | ||
mock_find_voltages.assert_called() | ||
expected_data = raw_data_pivoted[raw_data_pivoted.columns[1:]].to_numpy() # type: ignore | ||
np.testing.assert_array_equal(mock_find_voltages.call_args[0][0], expected_data) # type: ignore | ||
np.testing.assert_almost_equal(mock_find_voltages.call_args[0][1], 100) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from unittest.mock import patch | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from bimorph_mirror_analysis.read_file import read_bluesky_plan_output | ||
|
||
|
||
def test_read_raw_data(raw_data: pd.DataFrame, raw_data_pivoted: pd.DataFrame): | ||
with patch("bimorph_mirror_analysis.read_file.pd.read_csv") as mock_read_csv: | ||
mock_read_csv.return_value = raw_data | ||
pivoted, initial_voltages, increment = read_bluesky_plan_output("input_path") | ||
pd.testing.assert_frame_equal(pivoted, raw_data_pivoted) | ||
np.testing.assert_array_equal(initial_voltages, np.array([0.0, 0.0, 0.0])) | ||
np.testing.assert_equal(increment, np.float64(100.0)) | ||
mock_read_csv.assert_called() | ||
|
||
|
||
@pytest.mark.xfail( | ||
reason="This test is expected to fail, the incrememnt should be 100, not 101" | ||
) | ||
def test_read_raw_data_fail(raw_data_pivoted: pd.DataFrame): | ||
with patch( | ||
"bimorph_mirror_analysis.read_file.read_bluesky_plan_output" | ||
) as mock_read_bluesky_plan_output: | ||
mock_read_bluesky_plan_output.return_value = ( | ||
raw_data_pivoted, | ||
np.array([0.0, 0.0, 0.0]), | ||
np.float64(101.0), | ||
) | ||
pivoted, initial_voltages, increment = mock_read_bluesky_plan_output() | ||
expected_output = pd.read_csv("tests/data/raw_data_pivoted.csv") # type: ignore | ||
pd.testing.assert_frame_equal(pivoted, expected_output) | ||
np.testing.assert_array_equal(initial_voltages, np.array([0.0, 0.0, 0.0])) | ||
np.testing.assert_equal(increment, np.float64(100.0)) |
Oops, something went wrong.