From 59ba0e067de3cf419c7cd3f2073325a6b8cf00e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joan=20H=C3=A9risson?= Date: Fri, 12 Jul 2024 17:39:30 +0200 Subject: [PATCH] feat(sampler): fix component value --- icfree/sampler.py | 34 +++++++++++++++++++++++++++++----- tests/test_sampler.py | 19 ++++++++++++++----- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/icfree/sampler.py b/icfree/sampler.py index f48f5a8..8b685ed 100644 --- a/icfree/sampler.py +++ b/icfree/sampler.py @@ -3,8 +3,9 @@ import numpy as np import random from pyDOE2 import lhs +import ast -def generate_lhs_samples(input_file, num_samples, step, seed=None): +def generate_lhs_samples(input_file, num_samples, step, fixed_values=None, seed=None): """ Generates Latin Hypercube Samples for components based on discrete ranges. @@ -12,6 +13,7 @@ def generate_lhs_samples(input_file, num_samples, step, seed=None): - input_file: Path to the input file containing components and their max values. - num_samples: Number of samples to generate. - step: Step size for creating discrete ranges. + - fixed_values: Dictionary of components with fixed values (optional). - seed: Random seed for reproducibility. Returns: @@ -30,7 +32,12 @@ def generate_lhs_samples(input_file, num_samples, step, seed=None): # Generate discrete ranges for each component for index, row in components_df.iterrows(): - component_range = np.arange(0, row['maxValue'] + step, step) + component_name = row['Component'] + if fixed_values and component_name in fixed_values: + # If the component has a fixed value, use a single-element array + component_range = np.array([fixed_values[component_name]]) + else: + component_range = np.arange(0, row['maxValue'] + step, step) discrete_ranges.append(component_range) # Determine the number of components @@ -48,7 +55,7 @@ def generate_lhs_samples(input_file, num_samples, step, seed=None): samples_df = pd.DataFrame(samples, columns=components_df['Component']) return samples_df -def main(input_file, output_file, num_samples, step=2.5, seed=None): +def main(input_file, output_file, num_samples, step=2.5, fixed_values=None, seed=None): """ Main function to generate LHS samples and save them to a CSV file. @@ -57,10 +64,23 @@ def main(input_file, output_file, num_samples, step=2.5, seed=None): - output_file: Path to the output CSV file where samples will be written. - num_samples: Number of samples to generate. - step: Step size for creating discrete ranges (default: 2.5). + - fixed_values: Dictionary of components with fixed values (optional). - seed: Random seed for reproducibility (optional). """ + # Read the input file + components_df = pd.read_csv(input_file, sep='\t') + + # Get the list of components from the input file + component_names = components_df['Component'].tolist() + + # Check for fixed values that are not in the list of components + if fixed_values: + for component in fixed_values.keys(): + if component not in component_names: + print(f"Warning: Component '{component}' not found in the input file.") + # Generate LHS samples - samples_df = generate_lhs_samples(input_file, num_samples, step, seed) + samples_df = generate_lhs_samples(input_file, num_samples, step, fixed_values, seed) # Write the samples to a CSV file samples_df.to_csv(output_file, index=False) @@ -73,10 +93,14 @@ def main(input_file, output_file, num_samples, step=2.5, seed=None): parser.add_argument('output_file', type=str, help='Output CSV file path for the samples.') parser.add_argument('num_samples', type=int, help='Number of samples to generate.') parser.add_argument('--step', type=float, default=2.5, help='Step size for creating discrete ranges (default: 2.5).') + parser.add_argument('--fixed_values', type=str, default=None, help='Fixed values for components as a dictionary (e.g., \'{"Component1": 10, "Component2": 20}\')') parser.add_argument('--seed', type=int, default=None, help='Seed for random number generation for reproducibility (optional).') # Parse arguments args = parser.parse_args() + # Convert fixed_values argument from string to dictionary if provided + fixed_values = ast.literal_eval(args.fixed_values) if args.fixed_values else None + # Run the main function with the parsed arguments - main(args.input_file, args.output_file, args.num_samples, args.step, args.seed) + main(args.input_file, args.output_file, args.num_samples, args.step, fixed_values, args.seed) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 5017e6f..7976a26 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -19,7 +19,7 @@ def setUp(self): def test_generate_lhs_samples_normal(self, mock_read_csv): mock_read_csv.return_value = self.components_df - result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, self.seed) + result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, self.seed) self.assertEqual(result.shape, (self.num_samples, 3)) self.assertListEqual(list(result.columns), ['A', 'B', 'C']) @@ -28,7 +28,7 @@ def test_generate_lhs_samples_normal(self, mock_read_csv): def test_generate_lhs_samples_no_seed(self, mock_read_csv): mock_read_csv.return_value = self.components_df - result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None) + result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, None) self.assertEqual(result.shape, (self.num_samples, 3)) self.assertListEqual(list(result.columns), ['A', 'B', 'C']) @@ -39,7 +39,7 @@ def test_generate_lhs_samples_edge_case_zero_maxValue(self, mock_read_csv): edge_case_df.loc[0, 'maxValue'] = 0 # Set maxValue of component 'A' to 0 mock_read_csv.return_value = edge_case_df - result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, self.seed) + result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, None, self.seed) self.assertEqual(result.shape, (self.num_samples, 3)) self.assertTrue((result['A'] == 0).all()) # All values in column 'A' should be zero @@ -49,14 +49,23 @@ def test_generate_lhs_samples_invalid_step(self, mock_read_csv): mock_read_csv.return_value = self.components_df with self.assertRaises(IndexError): - generate_lhs_samples("fake_path.csv", self.num_samples, -2.5, self.seed) # Negative step size should raise an error + generate_lhs_samples("fake_path.csv", self.num_samples, -2.5, None, self.seed) # Negative step size should raise an error @patch("icfree.sampler.pd.read_csv") def test_generate_lhs_samples_invalid_input_file(self, mock_read_csv): mock_read_csv.side_effect = FileNotFoundError with self.assertRaises(FileNotFoundError): - generate_lhs_samples("invalid_path.csv", self.num_samples, self.step, self.seed) + generate_lhs_samples("invalid_path.csv", self.num_samples, self.step, None, self.seed) + + @patch("icfree.sampler.pd.read_csv") + def test_generate_lhs_samples_fix_component_value(self, mock_read_csv): + mock_read_csv.return_value = self.components_df + + result = generate_lhs_samples("fake_path.csv", self.num_samples, self.step, {'A': 5}, self.seed) + + self.assertEqual(result.shape, (self.num_samples, 3)) + self.assertTrue((result['A'] == 5).all()) if __name__ == "__main__": unittest.main()