diff --git a/src/forcedphot/__init__.py b/src/forcedphot/__init__.py index c753603..e69de29 100644 --- a/src/forcedphot/__init__.py +++ b/src/forcedphot/__init__.py @@ -1,5 +0,0 @@ -from .example_module import greetings, meaning -from .horizons_interface import HorizonsInterface -from .local_dataclasses import QueryInput, QueryResult, EphemerisData - -__all__ = ["greetings", "meaning", "HorizonsInterface", "QueryInput", "QueryResult", "EphemerisData"] diff --git a/src/forcedphot/horizons_interface.py b/src/forcedphot/horizons_interface.py index 7d104d9..463cd17 100644 --- a/src/forcedphot/horizons_interface.py +++ b/src/forcedphot/horizons_interface.py @@ -6,7 +6,8 @@ import pandas as pd from astropy.time import Time from astroquery.jplhorizons import Horizons -from local_dataclasses import EphemerisData, QueryInput, QueryResult + +from .local_dataclasses import EphemerisData, QueryInput, QueryResult class HorizonsInterface: @@ -37,7 +38,7 @@ class HorizonsInterface: logger = logging.getLogger(__name__) # Rubin location - DEFAULT_OBSERVER_LOCATION = 'X05' + DEFAULT_OBSERVER_LOCATION = "X05" def __init__(self, observer_location=DEFAULT_OBSERVER_LOCATION): """ @@ -52,7 +53,6 @@ def __init__(self, observer_location=DEFAULT_OBSERVER_LOCATION): """ self.observer_location = observer_location - def query_single_range(self, query: QueryInput) -> QueryResult: """ Query ephemeris for a single time range. @@ -78,33 +78,41 @@ def query_single_range(self, query: QueryInput) -> QueryResult: """ try: start_time = time.time() - obj = Horizons(id_type='smallbody', id=query.target, location=self.observer_location, - epochs={'start': query.start.iso, 'stop': query.end.iso, 'step': query.step}) + obj = Horizons( + id_type="smallbody", + id=query.target, + location=self.observer_location, + epochs={"start": query.start.iso, "stop": query.end.iso, "step": query.step}, + ) ephemeris = obj.ephemerides() end_time = time.time() - self.logger.info(f"Query for range {query.start} to {query.end} successful for target" - f"{query.target}. Time taken: {end_time - start_time:.2f} seconds.") + self.logger.info( + f"Query for range {query.start} to {query.end} successful for target" + f"{query.target}. Time taken: {end_time - start_time:.2f} seconds." + ) ephemeris_data = EphemerisData( - datetime_jd=Time(ephemeris['datetime_jd'], format='jd'), - datetime_iso=Time(Time(ephemeris['datetime_jd'], format='jd').iso, format='iso'), - RA_deg=np.array(ephemeris['RA']), - DEC_deg=np.array(ephemeris['DEC']), - RA_rate_arcsec_per_h=np.array(ephemeris['RA_rate']), - DEC_rate_arcsec_per_h=np.array(ephemeris['DEC_rate']), - AZ_deg=np.array(ephemeris['AZ']), - EL_deg=np.array(ephemeris['EL']), - r_au=np.array(ephemeris['r']), - delta_au=np.array(ephemeris['delta']), - V_mag=np.array(ephemeris['V']), - alpha_deg=np.array(ephemeris['alpha']), - RSS_3sigma_arcsec=np.array(ephemeris['RSS_3sigma']) + datetime_jd=Time(ephemeris["datetime_jd"], format="jd"), + datetime_iso=Time(Time(ephemeris["datetime_jd"], format="jd").iso, format="iso"), + RA_deg=np.array(ephemeris["RA"]), + DEC_deg=np.array(ephemeris["DEC"]), + RA_rate_arcsec_per_h=np.array(ephemeris["RA_rate"]), + DEC_rate_arcsec_per_h=np.array(ephemeris["DEC_rate"]), + AZ_deg=np.array(ephemeris["AZ"]), + EL_deg=np.array(ephemeris["EL"]), + r_au=np.array(ephemeris["r"]), + delta_au=np.array(ephemeris["delta"]), + V_mag=np.array(ephemeris["V"]), + alpha_deg=np.array(ephemeris["alpha"]), + RSS_3sigma_arcsec=np.array(ephemeris["RSS_3sigma"]), ) return QueryResult(query.target, query.start, query.end, ephemeris_data) except Exception as e: - self.logger.error(f"An error occurred during query for range {query.start} to {query.end}" - f"for target {query.target}: {e}") + self.logger.error( + f"An error occurred during query for range {query.start} to {query.end}" + f"for target {query.target}: {e}" + ) return None @classmethod @@ -152,9 +160,9 @@ def query_ephemeris_from_csv(cls, csv_filename: str, observer_location=DEFAULT_O for _index, row in df.iterrows(): query = QueryInput( target=row.iloc[0], - start=Time(row.iloc[1], scale='utc'), - end=Time(row.iloc[2], scale='utc'), - step=row.iloc[3] + start=Time(row.iloc[1], scale="utc"), + end=Time(row.iloc[2], scale="utc"), + step=row.iloc[3], ) # Calculate the total number of instances @@ -166,13 +174,18 @@ def query_ephemeris_from_csv(cls, csv_filename: str, observer_location=DEFAULT_O # Check if multiple queries are needed if step_instances > max_instances: - cls.logger.info(f"Total instances exceed 10,000 for target {query.target}. Splitting" - f"the queries.") + cls.logger.info( + f"Total instances exceed 10,000 for target {query.target}. Splitting" f"the queries." + ) time_splits = int(step_instances // max_instances) + 1 - time_ranges = [(query.start + i * (total_days / time_splits) * u.day, - query.start + (i + 1) * (total_days / time_splits) * u.day) - for i in range(time_splits)] + time_ranges = [ + ( + query.start + i * (total_days / time_splits) * u.day, + query.start + (i + 1) * (total_days / time_splits) * u.day, + ) + for i in range(time_splits) + ] else: time_ranges = [(query.start, query.end)] @@ -180,56 +193,85 @@ def query_ephemeris_from_csv(cls, csv_filename: str, observer_location=DEFAULT_O # Run queries sequentially for start, end in time_ranges: - result = horizons_interface.query_single_range(QueryInput(query.target, start, end, query.step)) + result = horizons_interface.query_single_range( + QueryInput(query.target, start, end, query.step) + ) if result is not None: - all_ephemeris.datetime_jd = Time(np.concatenate((all_ephemeris.datetime_jd.jd, result.ephemeris.datetime_jd.jd)), format='jd') - all_ephemeris.datetime_iso = Time(np.concatenate((all_ephemeris.datetime_iso.iso, result.ephemeris.datetime_iso.iso)), format='iso') + all_ephemeris.datetime_jd = Time( + np.concatenate((all_ephemeris.datetime_jd.jd, result.ephemeris.datetime_jd.jd)), + format="jd", + ) + all_ephemeris.datetime_iso = Time( + np.concatenate( + (all_ephemeris.datetime_iso.iso, result.ephemeris.datetime_iso.iso) + ), + format="iso", + ) all_ephemeris.RA_deg = np.concatenate((all_ephemeris.RA_deg, result.ephemeris.RA_deg)) - all_ephemeris.DEC_deg = np.concatenate((all_ephemeris.DEC_deg, result.ephemeris.DEC_deg)) - all_ephemeris.RA_rate_arcsec_per_h = np.concatenate((all_ephemeris.RA_rate_arcsec_per_h, result.ephemeris.RA_rate_arcsec_per_h)) - all_ephemeris.DEC_rate_arcsec_per_h = np.concatenate((all_ephemeris.DEC_rate_arcsec_per_h, result.ephemeris.DEC_rate_arcsec_per_h)) + all_ephemeris.DEC_deg = np.concatenate( + (all_ephemeris.DEC_deg, result.ephemeris.DEC_deg) + ) + all_ephemeris.RA_rate_arcsec_per_h = np.concatenate( + (all_ephemeris.RA_rate_arcsec_per_h, result.ephemeris.RA_rate_arcsec_per_h) + ) + all_ephemeris.DEC_rate_arcsec_per_h = np.concatenate( + (all_ephemeris.DEC_rate_arcsec_per_h, result.ephemeris.DEC_rate_arcsec_per_h) + ) all_ephemeris.AZ_deg = np.concatenate((all_ephemeris.AZ_deg, result.ephemeris.AZ_deg)) all_ephemeris.EL_deg = np.concatenate((all_ephemeris.EL_deg, result.ephemeris.EL_deg)) all_ephemeris.r_au = np.concatenate((all_ephemeris.r_au, result.ephemeris.r_au)) - all_ephemeris.delta_au = np.concatenate((all_ephemeris.delta_au, result.ephemeris.delta_au)) - all_ephemeris.V_mag = np.concatenate((all_ephemeris.V_mag , result.ephemeris.V_mag )) - all_ephemeris.alpha_deg = np.concatenate((all_ephemeris.alpha_deg, result.ephemeris.alpha_deg)) - all_ephemeris.RSS_3sigma_arcsec = np.concatenate((all_ephemeris.RSS_3sigma_arcsec, result.ephemeris.RSS_3sigma_arcsec)) + all_ephemeris.delta_au = np.concatenate( + (all_ephemeris.delta_au, result.ephemeris.delta_au) + ) + all_ephemeris.V_mag = np.concatenate((all_ephemeris.V_mag, result.ephemeris.V_mag)) + all_ephemeris.alpha_deg = np.concatenate( + (all_ephemeris.alpha_deg, result.ephemeris.alpha_deg) + ) + all_ephemeris.RSS_3sigma_arcsec = np.concatenate( + (all_ephemeris.RSS_3sigma_arcsec, result.ephemeris.RSS_3sigma_arcsec) + ) # Convert to pandas DataFrame - relevant_data = pd.DataFrame({ - 'datetime_jd': all_ephemeris.datetime_jd.jd, - 'datetime_iso': all_ephemeris.datetime_iso.iso, - 'RA': all_ephemeris.RA_deg, - 'DEC': all_ephemeris.DEC_deg, - 'RA_rate': all_ephemeris.RA_rate_arcsec_per_h, - 'DEC_rate': all_ephemeris.DEC_rate_arcsec_per_h, - 'AZ': all_ephemeris.AZ_deg, - 'EL': all_ephemeris.EL_deg, - 'r': all_ephemeris.r_au, - 'delta': all_ephemeris.delta_au, - 'V': all_ephemeris.V_mag, - 'alpha': all_ephemeris.alpha_deg, - 'RSS_3sigma': all_ephemeris.RSS_3sigma_arcsec - }) + relevant_data = pd.DataFrame( + { + "datetime_jd": all_ephemeris.datetime_jd.jd, + "datetime_iso": all_ephemeris.datetime_iso.iso, + "RA": all_ephemeris.RA_deg, + "DEC": all_ephemeris.DEC_deg, + "RA_rate": all_ephemeris.RA_rate_arcsec_per_h, + "DEC_rate": all_ephemeris.DEC_rate_arcsec_per_h, + "AZ": all_ephemeris.AZ_deg, + "EL": all_ephemeris.EL_deg, + "r": all_ephemeris.r_au, + "delta": all_ephemeris.delta_au, + "V": all_ephemeris.V_mag, + "alpha": all_ephemeris.alpha_deg, + "RSS_3sigma": all_ephemeris.RSS_3sigma_arcsec, + } + ) # Generate output filename - output_filename = f"{query.target}_{query.start.iso}_{query.end.iso}.csv".replace(":", "-").replace(" ", "_") + output_filename = f"{query.target}_{query.start.iso}_{query.end.iso}.csv".replace( + ":", "-" + ).replace(" ", "_") # Save the data to a CSV file relevant_data.to_csv(output_filename, index=False) cls.logger.info(f"Ephemeris data successfully saved to {output_filename}") total_end_time = time.time() - cls.logger.info(f"Total time taken for processing the CSV file:" - f"{total_end_time - total_start_time:.2f} seconds.") + cls.logger.info( + f"Total time taken for processing the CSV file:" + f"{total_end_time - total_start_time:.2f} seconds." + ) except Exception as e: cls.logger.error(f"An error occurred while processing the CSV file: {e}") + # Example usage if __name__ == "__main__": - HorizonsInterface.query_ephemeris_from_csv('./data/targets.csv') + HorizonsInterface.query_ephemeris_from_csv("./data/targets.csv") # Different observer location # HorizonsInterface.query_ephemeris_from_csv('targets.csv', observer_location='500') diff --git a/tests/forcedphot/test_horizons_interface.py b/tests/forcedphot/test_horizons_interface.py index 9647ed6..33c24aa 100644 --- a/tests/forcedphot/test_horizons_interface.py +++ b/tests/forcedphot/test_horizons_interface.py @@ -4,32 +4,42 @@ import pandas as pd import pytest from astropy.time import Time -from horizons_interface import HorizonsInterface -from local_dataclasses import EphemerisData, QueryInput, QueryResult +from forcedphot import horizons_interface, local_dataclasses + @pytest.fixture def mock_horizons(): - with patch("horizons_interface.Horizons") as mock: + """ + Fixture to mock the Horizons class for testing. + """ + with patch("forcedphot.horizons_interface.Horizons") as mock: yield mock + @pytest.fixture def mock_csv_data(): - return pd.DataFrame({ - "target": ["Ceres"], - "start": ["2020-01-01"], - "end": ["2020-01-02"], - "step": ["1h"] - }) + """ + Fixture to provide mock CSV data for testing. + """ + return pd.DataFrame({"target": ["Ceres"], "start": ["2020-01-01"], "end": ["2020-01-02"], "step": ["1h"]}) + def test_init(): - hi = HorizonsInterface() - assert hi.observer_location == HorizonsInterface.DEFAULT_OBSERVER_LOCATION + """ + Test the initialization of HorizonsInterface with default and custom observer locations. + """ + hi = horizons_interface.HorizonsInterface() + assert hi.observer_location == horizons_interface.HorizonsInterface.DEFAULT_OBSERVER_LOCATION custom_location = "X06" - hi_custom = HorizonsInterface(observer_location=custom_location) + hi_custom = horizons_interface.HorizonsInterface(observer_location=custom_location) assert hi_custom.observer_location == custom_location + def test_query_single_range_success(mock_horizons): + """ + Test successful query of a single range using mocked Horizons data. + """ mock_ephemerides = MagicMock() mock_ephemerides.return_value = { "datetime_jd": [2459000.5], @@ -43,54 +53,73 @@ def test_query_single_range_success(mock_horizons): "delta": [0.8], "V": [15.0], "alpha": [30.0], - "RSS_3sigma": [0.1] + "RSS_3sigma": [0.1], } mock_horizons.return_value.ephemerides = mock_ephemerides - hi = HorizonsInterface() - query = QueryInput("Ceres", Time("2020-01-01"), Time("2020-01-02"), "1h") + hi = horizons_interface.HorizonsInterface() + query = local_dataclasses.QueryInput("Ceres", Time("2020-01-01"), Time("2020-01-02"), "1h") result = hi.query_single_range(query) assert result is not None assert result.target == "Ceres" assert result.start == Time("2020-01-01") assert result.end == Time("2020-01-02") - assert isinstance(result.ephemeris, EphemerisData) + assert isinstance(result.ephemeris, local_dataclasses.EphemerisData) + def test_query_single_range_failure(mock_horizons): + """ + Test failure handling when querying a single range with an invalid target. + """ mock_horizons.side_effect = Exception("Query failed") - hi = HorizonsInterface() - query = QueryInput("Invalid Target", Time("2020-01-01"), Time("2020-01-02"), "1h") + hi = horizons_interface.HorizonsInterface() + query = local_dataclasses.QueryInput("Invalid Target", Time("2020-01-01"), Time("2020-01-02"), "1h") result = hi.query_single_range(query) assert result is None -@pytest.mark.parametrize("target,start,end,step", [ - ("Ceres", "2020-01-01", "2020-01-02", "1h"), - ("2021 XY", "2021-06-01", "2021-06-30", "2h"), -]) + +@pytest.mark.parametrize( + "target,start,end,step", + [ + ("Ceres", "2020-01-01", "2020-01-02", "1h"), + ("2021 XY", "2021-06-01", "2021-06-30", "2h"), + ], +) def test_query_input_creation(target, start, end, step): - query = QueryInput(target, Time(start), Time(end), step) + """ + Test creation of QueryInput objects with various parameters. + """ + query = local_dataclasses.QueryInput(target, Time(start), Time(end), step) assert query.target == target assert query.start == Time(start) assert query.end == Time(end) assert query.step == step + def test_ephemeris_data_creation(): - ephemeris = EphemerisData() + """ + Test creation of EphemerisData object and verify its attribute types. + """ + ephemeris = local_dataclasses.EphemerisData() assert isinstance(ephemeris.datetime_jd, Time) assert isinstance(ephemeris.datetime_iso, Time) assert isinstance(ephemeris.RA_deg, np.ndarray) assert isinstance(ephemeris.DEC_deg, np.ndarray) # Add similar assertions for other attributes + @patch("pandas.read_csv") -@patch("horizons_interface.HorizonsInterface.query_single_range") +@patch("forcedphot.horizons_interface.HorizonsInterface.query_single_range") def test_query_ephemeris_from_csv(mock_query_single_range, mock_read_csv, mock_csv_data): + """ + Test querying ephemeris data from a CSV file using mocked dependencies. + """ mock_read_csv.return_value = mock_csv_data - mock_ephemeris = EphemerisData( + mock_ephemeris = local_dataclasses.EphemerisData( datetime_jd=Time([2459000.5], format="jd"), datetime_iso=Time(["2020-01-01 00:00:00.000"], format="iso"), RA_deg=np.array([100.0]), @@ -103,13 +132,15 @@ def test_query_ephemeris_from_csv(mock_query_single_range, mock_read_csv, mock_c delta_au=np.array([0.8]), V_mag=np.array([15.0]), alpha_deg=np.array([30.0]), - RSS_3sigma_arcsec=np.array([0.1]) + RSS_3sigma_arcsec=np.array([0.1]), + ) + mock_query_result = local_dataclasses.QueryResult( + "Ceres", Time("2020-01-01"), Time("2020-01-02"), mock_ephemeris ) - mock_query_result = QueryResult("Ceres", Time("2020-01-01"), Time("2020-01-02"), mock_ephemeris) mock_query_single_range.return_value = mock_query_result with patch("builtins.open", create=True), patch("pandas.DataFrame.to_csv") as mock_to_csv: - HorizonsInterface.query_ephemeris_from_csv("test.csv") + horizons_interface.HorizonsInterface.query_ephemeris_from_csv("test.csv") mock_read_csv.assert_called_once_with("test.csv") mock_query_single_range.assert_called_once()