From 90c58d492badb6a5230f5ea704474c45512488fc Mon Sep 17 00:00:00 2001 From: Cloud1e Date: Thu, 24 Oct 2024 21:52:48 +0000 Subject: [PATCH] :art: Format Python code with psf/black --- python/mspasspy/algorithms/ml/arrival.py | 27 +++++++---- python/mspasspy/algorithms/resample.py | 32 ++++++------- python/mspasspy/algorithms/window.py | 6 +-- python/mspasspy/util/converter.py | 4 +- python/tests/algorithms/ml/test_arrival.py | 22 ++++++--- python/tests/algorithms/test_window.py | 56 ++++++++++++---------- python/tests/util/test_converter.py | 22 ++++++--- setup.py | 2 +- 8 files changed, 101 insertions(+), 70 deletions(-) diff --git a/python/mspasspy/algorithms/ml/arrival.py b/python/mspasspy/algorithms/ml/arrival.py index 49eef1fe9..3362ac531 100644 --- a/python/mspasspy/algorithms/ml/arrival.py +++ b/python/mspasspy/algorithms/ml/arrival.py @@ -40,7 +40,11 @@ def annotate_arrival_time( # Check the input arguments if not 0 <= threshold <= 1: - logging.warning('Threshold should be in the range of [0, 1]. Using default threshold {}}'.format(default_threshold)) + logging.warning( + "Threshold should be in the range of [0, 1]. Using default threshold {}}".format( + default_threshold + ) + ) threshold = default_threshold # convert timeseries to absolute time @@ -51,7 +55,11 @@ def annotate_arrival_time( # 'stead' model was trained on STEAD for 100 epochs with a learning rate of 0.01. # use sbm.PhaseNet.list_pretrained(details=True) to list out other supported models # when using this model, please reference the SeisBench publications listed at https://github.com/seisbench/seisbench - pretrained_model = "stead" if (model_args == None or "name" not in model_args) else model_args["name"] + pretrained_model = ( + "stead" + if (model_args == None or "name" not in model_args) + else model_args["name"] + ) model = sbm.PhaseNet.from_pretrained(pretrained_model) ts_ensemble = TimeSeriesEnsemble() @@ -59,23 +67,26 @@ def annotate_arrival_time( stream = ts_ensemble.toStream() # apply the window if provided and convert time series to stream - start_time_utc = stream[0].stats.starttime.timestamp # UTC timestamp - end_time_utc = stream[0].stats.endtime.timestamp # UTC timestamp + start_time_utc = stream[0].stats.starttime.timestamp # UTC timestamp + end_time_utc = stream[0].stats.endtime.timestamp # UTC timestamp # adjust the time window if it is out of the time range of the time series if time_window: if time_window.end < start_time_utc or time_window.start > end_time_utc: time_window.start = start_time_utc time_window.end = end_time_utc - logging.warning('Time window is out of the time range of the time series. Adjusting the time window to the time range of the time series.') + logging.warning( + "Time window is out of the time range of the time series. Adjusting the time window to the time range of the time series." + ) if time_window.end > end_time_utc: time_window.end = end_time_utc if time_window.start < start_time_utc: time_window.start = start_time_utc - + windowed_stream = ( - stream.trim(UTCDateTime(time_window.start), UTCDateTime(time_window.end)) \ - if time_window else stream + stream.trim(UTCDateTime(time_window.start), UTCDateTime(time_window.end)) + if time_window + else stream ) # prediction result is the probability for picks over time diff --git a/python/mspasspy/algorithms/resample.py b/python/mspasspy/algorithms/resample.py index 1386d074b..e37c7835e 100755 --- a/python/mspasspy/algorithms/resample.py +++ b/python/mspasspy/algorithms/resample.py @@ -208,8 +208,8 @@ def resample(self, mspass_object): # We do this test at the top to avoid having returns testing for # a dead datum in each of the if conditional blocks below if isinstance( - mspass_object, - (TimeSeries, Seismogram, TimeSeriesEnsemble, SeismogramEnsemble), + mspass_object, + (TimeSeries, Seismogram, TimeSeriesEnsemble, SeismogramEnsemble), ): if mspass_object.dead(): return mspass_object @@ -221,7 +221,7 @@ def resample(self, mspass_object): if isinstance(mspass_object, TimeSeries): data_time_span = ( - mspass_object.endtime() - mspass_object.t0 + mspass_object.dt + mspass_object.endtime() - mspass_object.t0 + mspass_object.dt ) n_resampled = int(data_time_span * self.samprate) rsdata = signal.resample( @@ -236,7 +236,7 @@ def resample(self, mspass_object): mspass_object.data = dv elif isinstance(mspass_object, Seismogram): data_time_span = ( - mspass_object.endtime() - mspass_object.t0 + mspass_object.dt + mspass_object.endtime() - mspass_object.t0 + mspass_object.dt ) n_resampled = int(data_time_span * self.samprate) rsdata = signal.resample( @@ -348,8 +348,8 @@ def resample(self, mspass_object): # We do this test at the top to avoid having returns testing for # a dead datum in each of the if conditional blocks below if isinstance( - mspass_object, - (TimeSeries, Seismogram, TimeSeriesEnsemble, SeismogramEnsemble), + mspass_object, + (TimeSeries, Seismogram, TimeSeriesEnsemble, SeismogramEnsemble), ): if mspass_object.dead(): return mspass_object @@ -426,16 +426,16 @@ def resample(self, mspass_object): @mspass_func_wrapper def resample( - mspass_object, - decimator, - resampler, - verify_operators=True, - object_history=False, - alg_name="resample", - alg_id=None, - dryrun=False, - inplace_return=False, - function_return_key=None, + mspass_object, + decimator, + resampler, + verify_operators=True, + object_history=False, + alg_name="resample", + alg_id=None, + dryrun=False, + inplace_return=False, + function_return_key=None, ): """ Resample any valid data object to a common sample rate (sample interval). diff --git a/python/mspasspy/algorithms/window.py b/python/mspasspy/algorithms/window.py index 1d25f849f..0642c7679 100644 --- a/python/mspasspy/algorithms/window.py +++ b/python/mspasspy/algorithms/window.py @@ -889,8 +889,8 @@ def WindowData_autopad( return dw -# TODO: this function does not support history mechanism because the -# standard decorator is does not support a bound std::vector +# TODO: this function does not support history mechanism because the +# standard decorator is does not support a bound std::vector # container. I requires one of the four MsPASS data objects. def merge( tsvector, @@ -1040,7 +1040,7 @@ def merge( dead. When set True, gaps will be zeroed and with a record of gap positions posted to the Metadata of the output. See above for details. - :param zero_gaps: boolean controlling how gaps are to be handled. + :param zero_gaps: boolean controlling how gaps are to be handled. See above for details of the algorithm. :type zero_gaps: boolean (default False) :param object_history: boolean to enable or disable saving object diff --git a/python/mspasspy/util/converter.py b/python/mspasspy/util/converter.py index 2dc75a9eb..6e80fbe98 100644 --- a/python/mspasspy/util/converter.py +++ b/python/mspasspy/util/converter.py @@ -147,9 +147,7 @@ def TimeSeries2Trace(ts): if abs(sampling_rate - 1.0 / ts.dt) > 1e-6: # Record inconsistency in error log (elog) message = "sampling_rate inconsistent with 1/dt; updating to 1/dt" - ts.elog.log_error("TimeSeries2Trace", - message, - ErrorSeverity.Complaint) + ts.elog.log_error("TimeSeries2Trace", message, ErrorSeverity.Complaint) # Update sampling_rate to 1/dt ts["sampling_rate"] = 1.0 / ts.dt dresult.stats["sampling_rate"] = 1.0 / ts.dt diff --git a/python/tests/algorithms/ml/test_arrival.py b/python/tests/algorithms/ml/test_arrival.py index 9d0b0c9f5..be3d2ff08 100644 --- a/python/tests/algorithms/ml/test_arrival.py +++ b/python/tests/algorithms/ml/test_arrival.py @@ -15,6 +15,7 @@ pn_model = sbm.PhaseNet.from_pretrained("stead") + def test_annotate_arrival_time(): """ Test the annotate_arrival_time function. @@ -26,7 +27,7 @@ def test_annotate_arrival_time(): # picks from mspass timeseries = Trace2TimeSeries(stream[0]) annotate_arrival_time(timeseries, 0) - mspass_picks = timeseries["p_wave_picks"] # should be a dictionary + mspass_picks = timeseries["p_wave_picks"] # should be a dictionary # assert that for each pick, the value is a float number that is between 0 and 1 assert all(0 <= value <= 1 for value in mspass_picks.values()) @@ -34,7 +35,7 @@ def test_annotate_arrival_time(): pn_preds = pn_model.annotate(Stream(stream[0])) trace = pn_preds[0] assert trace.stats.channel == "PhaseNet_P" - seis_picks = trace.times("timestamp") # should be an array + seis_picks = trace.times("timestamp") # should be an array # Convert both to sets of rounded values mspass_set = set(round(v, 6) for v in mspass_picks.keys()) @@ -43,6 +44,7 @@ def test_annotate_arrival_time(): # Compare the sets assert mspass_set == seis_set + def test_annotate_arrival_time_threshold(): """ Test the annotate_arrival_time function with a threshold. @@ -64,6 +66,7 @@ def test_annotate_arrival_time_threshold(): # assert that the number of picks from mspass is less than the number of picks from seisbench assert len(mspass_picks) < len(seis_picks) + def test_annotate_arrival_time_window(): """ Test the annotate_arrival_time function with a time window. @@ -76,7 +79,9 @@ def test_annotate_arrival_time_window(): timeseries = Trace2TimeSeries(stream[0]) window_start = UTCDateTime(2009, 4, 6, 1, 30).timestamp window_end = window_start + 1000 - annotate_arrival_time(timeseries, threshold = 0, time_window=TimeWindow(window_start, window_end)) + annotate_arrival_time( + timeseries, threshold=0, time_window=TimeWindow(window_start, window_end) + ) mspass_picks = timeseries["p_wave_picks"] assert len(mspass_picks.keys()) > 0 @@ -97,6 +102,7 @@ def test_annotate_arrival_time_window(): # Compare the sets assert mspass_set == seis_set + def test_annotate_arrival_time_for_mseed(): """ Test the annotate_arrival_time function for a mseed file. @@ -108,8 +114,10 @@ def test_annotate_arrival_time_for_mseed(): timeseries = Trace2TimeSeries(trace) window_start = UTCDateTime(2011, 3, 11, 6, 35).timestamp - window_end = window_start + 1200 # 20 minutes - annotate_arrival_time(timeseries, 0.1, time_window=TimeWindow(window_start, window_end)) + window_end = window_start + 1200 # 20 minutes + annotate_arrival_time( + timeseries, 0.1, time_window=TimeWindow(window_start, window_end) + ) mspass_picks = timeseries["p_wave_picks"] # assert the picks are not empty @@ -130,7 +138,7 @@ def test_annotate_arrival_time_for_mseed(): # Convert both to sets of rounded values mspass_set = set(round(v, 6) for v in mspass_picks.keys()) - seis_set = set(round(v, 6) for v in seis_picks) + seis_set = set(round(v, 6) for v in seis_picks) # every pick in mspass should be in seisbench assert mspass_set.issubset(seis_set) @@ -138,6 +146,7 @@ def test_annotate_arrival_time_for_mseed(): # assert that the number of picks from mspass is less than the number of picks from seisbench assert len(mspass_picks) < len(seis_picks) + def get_mseed_trace_for_test(): file_path = os.path.join(os.getcwd(), "python/tests/data/db_mseeds/test_277.mseed") st = read(file_path) @@ -158,5 +167,6 @@ def get_trace_for_test(): endtime=t + 3600, ) + if __name__ == "__main__": test_annotate_arrival_time_window() diff --git a/python/tests/algorithms/test_window.py b/python/tests/algorithms/test_window.py index 178195deb..12e560660 100644 --- a/python/tests/algorithms/test_window.py +++ b/python/tests/algorithms/test_window.py @@ -650,14 +650,15 @@ def test_windowdata_exceptions(): assert d.data[i, j] == 0.0 assert d.elog.size() == 1 + def test_WindowData_autopad(): """ - Tests function added September 2024 that uses a soft test - for the data range. Returns a fixed length signal that - is padded unless the actual time span is too short as - defined by the fractional_mismatch_limit argument. + Tests function added September 2024 that uses a soft test + for the data range. Returns a fixed length signal that + is padded unless the actual time span is too short as + defined by the fractional_mismatch_limit argument. The function is a higher level function using WindowData - so tests are limited to the features it adds and + so tests are limited to the features it adds and additional exceptions it handles. """ # this duplicates code in test_WindowData @@ -677,58 +678,61 @@ def test_WindowData_autopad(): # make copies because WindowData can alter content ts0 = TimeSeries(ts) se0 = Seismogram(se) - - # first verify the function works correctly if the + + # first verify the function works correctly if the # stime:etime range is entirely within the data ts = TimeSeries(ts0) - ts = WindowData_autopad(ts,ts.time(10),ts.time(50)) + ts = WindowData_autopad(ts, ts.time(10), ts.time(50)) assert ts.live assert ts.npts == 41 - + se = Seismogram(se0) - se = WindowData_autopad(se,se.time(10),se.time(50)) + se = WindowData_autopad(se, se.time(10), se.time(50)) assert se.live assert se.npts == 41 - - # verify automatic zero padding of first and last sample + + # verify automatic zero padding of first and last sample # when stime and etime are specified as that ts = TimeSeries(ts0) - # the time method works with negative and resolves her + # the time method works with negative and resolves her # to one sample before start and one sample after end time - ts = WindowData_autopad(ts,ts.time(-1),ts.time(npts)) + ts = WindowData_autopad(ts, ts.time(-1), ts.time(npts)) assert ts.live assert ts.npts == npts + 2 # isclose is not needed here as this is a hard set 0 assert ts.data[0] == 0.0 - assert ts.data[npts+1] == 0.0 + assert ts.data[npts + 1] == 0.0 # similar for Seismogram se = Seismogram(se0) - # the time method works with negative and resolves her + # the time method works with negative and resolves her # to one sample before start and one sample after end time - se = WindowData_autopad(se,se.time(-1),se.time(npts)) + se = WindowData_autopad(se, se.time(-1), se.time(npts)) assert se.live assert se.npts == npts + 2 for k in range(3): - assert se.data[k,0] == 0.0 - assert se.data[k,npts+1] == 0.0 - - # test error handling for this function - not made a separate function - # as it isn't that complex + assert se.data[k, 0] == 0.0 + assert se.data[k, npts + 1] == 0.0 + + # test error handling for this function - not made a separate function + # as it isn't that complex d_foo = TimeSeriesEnsemble(2) d_foo.member.append(ts) d_foo.set_live() - with pytest.raises(TypeError,match="arg0 must be either a TimeSeries or Seismogram object"): + with pytest.raises( + TypeError, match="arg0 must be either a TimeSeries or Seismogram object" + ): # need to send TimeSeriesEnsemble to bypass type check of function decorator - ts = WindowData_autopad(d_foo,se.time(10),se.time(50)) - + ts = WindowData_autopad(d_foo, se.time(10), se.time(50)) + # only test this for TimeSeries - no reason to think it would behave differently for Seismogram ts = TimeSeries(ts0) - ts = WindowData_autopad(ts,ts.t0,ts.time(3*npts)) + ts = WindowData_autopad(ts, ts.t0, ts.time(3 * npts)) assert ts.dead() # this is actually 2 because it passes through WindowDataAtomic that also posts a log message # Using >= to make the test more robust in the event that changes assert ts.elog.size() >= 1 + def test_TopMute(): ts = TimeSeries(100) seis = Seismogram(100) diff --git a/python/tests/util/test_converter.py b/python/tests/util/test_converter.py index 17f4141f2..3a5e9a6c5 100644 --- a/python/tests/util/test_converter.py +++ b/python/tests/util/test_converter.py @@ -14,11 +14,10 @@ from unittest import mock with mock.patch.dict( - sys.modules, {"pyspark": None, "dask": None, "dask.dataframe": None} + sys.modules, {"pyspark": None, "dask": None, "dask.dataframe": None} ): from mspasspy.util.converter import Textfile2Dataframe - def test_Textfile2Dataframe_no_parallel(): pf = AntelopePf("python/tests/data/test_import.pf") attributes = Pf2AttributeNameTbl(pf, tag="wfprocess") @@ -66,6 +65,7 @@ def test_Textfile2Dataframe_no_parallel(): textfile, header_line=0, parallel=p, insert_column={"test_col": 1} ) + from mspasspy.ccore.utility import dmatrix, Metadata, AntelopePf, MsPASSError from mspasspy.ccore.seismic import DoubleVector, Seismogram, TimeSeries from mspasspy.util.converter import ( @@ -204,8 +204,11 @@ def test_TimeSeries2Trace(): assert tr.stats["npts"] == test_TimeSeries2Trace.ts1.get("npts") assert tr.stats["sampling_rate"] == test_TimeSeries2Trace.ts1.get("sampling_rate") # error log should be empty or the message should not be updateSamplingRateMessage because the sampling_rate is defined and correct - assert test_TimeSeries2Trace.ts1.elog.size() == 0 or not test_TimeSeries2Trace.ts1.elog.get_error_log()[ - 0].message != updateSamplingRateMessage + assert ( + test_TimeSeries2Trace.ts1.elog.size() == 0 + or not test_TimeSeries2Trace.ts1.elog.get_error_log()[0].message + != updateSamplingRateMessage + ) # test for case when "sampling_rate" is not defined in ts1 # create a new copy of ts1 without "sampling_rate" defined ts_size = 255 @@ -229,7 +232,10 @@ def test_TimeSeries2Trace(): assert tr.stats["npts"] == ts1_copy.get("npts") assert tr.stats["sampling_rate"] == ts1_copy.get("sampling_rate") # error log should be empty or the message should not be updateSamplingRateMessage because the sampling_rate is not defined - assert ts1_copy.elog.size() == 0 or not ts1_copy.elog.get_error_log()[0].message != updateSamplingRateMessage + assert ( + ts1_copy.elog.size() == 0 + or not ts1_copy.elog.get_error_log()[0].message != updateSamplingRateMessage + ) # test for "sampling_rate" of ts1_copy assert ts1_copy.is_defined("sampling_rate") @@ -260,8 +266,10 @@ def test_TimeSeries2Trace(): assert tr.stats["npts"] == ts1_copy.get("npts") assert tr.stats["sampling_rate"] == ts1_copy.get("sampling_rate") # message of error log should be updateSamplingRateMessage because the sampling_rate is defined wrongly and need to be updated - assert ts1_copy.elog.get_error_log()[0].algorithm == "TimeSeries2Trace" and ts1_copy.elog.get_error_log()[ - 0].message == updateSamplingRateMessage + assert ( + ts1_copy.elog.get_error_log()[0].algorithm == "TimeSeries2Trace" + and ts1_copy.elog.get_error_log()[0].message == updateSamplingRateMessage + ) def test_Trace2TimeSeries(): diff --git a/setup.py b/setup.py index eed0ac804..660d9f1ec 100644 --- a/setup.py +++ b/setup.py @@ -98,5 +98,5 @@ def build_extension(self, ext): ), package_data={"": ["*.yaml", "*.pf"]}, include_package_data=True, - install_requires = ["pyyaml"] + install_requires=["pyyaml"], )