Skip to content

Commit

Permalink
🎨 Format Python code with psf/black
Browse files Browse the repository at this point in the history
  • Loading branch information
Cloud1e authored and wangyinz committed Oct 25, 2024
1 parent 649f6d7 commit 90c58d4
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 70 deletions.
27 changes: 19 additions & 8 deletions python/mspasspy/algorithms/ml/arrival.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ def annotate_arrival_time(

# Check the input arguments
if not 0 <= threshold <= 1:
logging.warning('Threshold should be in the range of [0, 1]. Using default threshold {}}'.format(default_threshold))
logging.warning(
"Threshold should be in the range of [0, 1]. Using default threshold {}}".format(
default_threshold
)
)
threshold = default_threshold

# convert timeseries to absolute time
Expand All @@ -51,31 +55,38 @@ def annotate_arrival_time(
# 'stead' model was trained on STEAD for 100 epochs with a learning rate of 0.01.
# use sbm.PhaseNet.list_pretrained(details=True) to list out other supported models
# when using this model, please reference the SeisBench publications listed at https://github.com/seisbench/seisbench
pretrained_model = "stead" if (model_args == None or "name" not in model_args) else model_args["name"]
pretrained_model = (
"stead"
if (model_args == None or "name" not in model_args)
else model_args["name"]
)
model = sbm.PhaseNet.from_pretrained(pretrained_model)

ts_ensemble = TimeSeriesEnsemble()
ts_ensemble.member.append(timeseries)
stream = ts_ensemble.toStream()

# apply the window if provided and convert time series to stream
start_time_utc = stream[0].stats.starttime.timestamp # UTC timestamp
end_time_utc = stream[0].stats.endtime.timestamp # UTC timestamp
start_time_utc = stream[0].stats.starttime.timestamp # UTC timestamp
end_time_utc = stream[0].stats.endtime.timestamp # UTC timestamp

# adjust the time window if it is out of the time range of the time series
if time_window:
if time_window.end < start_time_utc or time_window.start > end_time_utc:
time_window.start = start_time_utc
time_window.end = end_time_utc
logging.warning('Time window is out of the time range of the time series. Adjusting the time window to the time range of the time series.')
logging.warning(
"Time window is out of the time range of the time series. Adjusting the time window to the time range of the time series."
)
if time_window.end > end_time_utc:
time_window.end = end_time_utc
if time_window.start < start_time_utc:
time_window.start = start_time_utc

windowed_stream = (
stream.trim(UTCDateTime(time_window.start), UTCDateTime(time_window.end)) \
if time_window else stream
stream.trim(UTCDateTime(time_window.start), UTCDateTime(time_window.end))
if time_window
else stream
)

# prediction result is the probability for picks over time
Expand Down
32 changes: 16 additions & 16 deletions python/mspasspy/algorithms/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ def resample(self, mspass_object):
# We do this test at the top to avoid having returns testing for
# a dead datum in each of the if conditional blocks below
if isinstance(
mspass_object,
(TimeSeries, Seismogram, TimeSeriesEnsemble, SeismogramEnsemble),
mspass_object,
(TimeSeries, Seismogram, TimeSeriesEnsemble, SeismogramEnsemble),
):
if mspass_object.dead():
return mspass_object
Expand All @@ -221,7 +221,7 @@ def resample(self, mspass_object):

if isinstance(mspass_object, TimeSeries):
data_time_span = (
mspass_object.endtime() - mspass_object.t0 + mspass_object.dt
mspass_object.endtime() - mspass_object.t0 + mspass_object.dt
)
n_resampled = int(data_time_span * self.samprate)
rsdata = signal.resample(
Expand All @@ -236,7 +236,7 @@ def resample(self, mspass_object):
mspass_object.data = dv
elif isinstance(mspass_object, Seismogram):
data_time_span = (
mspass_object.endtime() - mspass_object.t0 + mspass_object.dt
mspass_object.endtime() - mspass_object.t0 + mspass_object.dt
)
n_resampled = int(data_time_span * self.samprate)
rsdata = signal.resample(
Expand Down Expand Up @@ -348,8 +348,8 @@ def resample(self, mspass_object):
# We do this test at the top to avoid having returns testing for
# a dead datum in each of the if conditional blocks below
if isinstance(
mspass_object,
(TimeSeries, Seismogram, TimeSeriesEnsemble, SeismogramEnsemble),
mspass_object,
(TimeSeries, Seismogram, TimeSeriesEnsemble, SeismogramEnsemble),
):
if mspass_object.dead():
return mspass_object
Expand Down Expand Up @@ -426,16 +426,16 @@ def resample(self, mspass_object):

@mspass_func_wrapper
def resample(
mspass_object,
decimator,
resampler,
verify_operators=True,
object_history=False,
alg_name="resample",
alg_id=None,
dryrun=False,
inplace_return=False,
function_return_key=None,
mspass_object,
decimator,
resampler,
verify_operators=True,
object_history=False,
alg_name="resample",
alg_id=None,
dryrun=False,
inplace_return=False,
function_return_key=None,
):
"""
Resample any valid data object to a common sample rate (sample interval).
Expand Down
6 changes: 3 additions & 3 deletions python/mspasspy/algorithms/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,8 +889,8 @@ def WindowData_autopad(
return dw


# TODO: this function does not support history mechanism because the
# standard decorator is does not support a bound std::vector<TimeSeries>
# TODO: this function does not support history mechanism because the
# standard decorator is does not support a bound std::vector<TimeSeries>
# container. I requires one of the four MsPASS data objects.
def merge(
tsvector,
Expand Down Expand Up @@ -1040,7 +1040,7 @@ def merge(
dead. When set True, gaps will be zeroed and with a record of
gap positions posted to the Metadata of the output. See above
for details.
:param zero_gaps: boolean controlling how gaps are to be handled.
:param zero_gaps: boolean controlling how gaps are to be handled.
See above for details of the algorithm.
:type zero_gaps: boolean (default False)
:param object_history: boolean to enable or disable saving object
Expand Down
4 changes: 1 addition & 3 deletions python/mspasspy/util/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,7 @@ def TimeSeries2Trace(ts):
if abs(sampling_rate - 1.0 / ts.dt) > 1e-6:
# Record inconsistency in error log (elog)
message = "sampling_rate inconsistent with 1/dt; updating to 1/dt"
ts.elog.log_error("TimeSeries2Trace",
message,
ErrorSeverity.Complaint)
ts.elog.log_error("TimeSeries2Trace", message, ErrorSeverity.Complaint)
# Update sampling_rate to 1/dt
ts["sampling_rate"] = 1.0 / ts.dt
dresult.stats["sampling_rate"] = 1.0 / ts.dt
Expand Down
22 changes: 16 additions & 6 deletions python/tests/algorithms/ml/test_arrival.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

pn_model = sbm.PhaseNet.from_pretrained("stead")


def test_annotate_arrival_time():
"""
Test the annotate_arrival_time function.
Expand All @@ -26,15 +27,15 @@ def test_annotate_arrival_time():
# picks from mspass
timeseries = Trace2TimeSeries(stream[0])
annotate_arrival_time(timeseries, 0)
mspass_picks = timeseries["p_wave_picks"] # should be a dictionary
mspass_picks = timeseries["p_wave_picks"] # should be a dictionary
# assert that for each pick, the value is a float number that is between 0 and 1
assert all(0 <= value <= 1 for value in mspass_picks.values())

# picks from seisbench
pn_preds = pn_model.annotate(Stream(stream[0]))
trace = pn_preds[0]
assert trace.stats.channel == "PhaseNet_P"
seis_picks = trace.times("timestamp") # should be an array
seis_picks = trace.times("timestamp") # should be an array

# Convert both to sets of rounded values
mspass_set = set(round(v, 6) for v in mspass_picks.keys())
Expand All @@ -43,6 +44,7 @@ def test_annotate_arrival_time():
# Compare the sets
assert mspass_set == seis_set


def test_annotate_arrival_time_threshold():
"""
Test the annotate_arrival_time function with a threshold.
Expand All @@ -64,6 +66,7 @@ def test_annotate_arrival_time_threshold():
# assert that the number of picks from mspass is less than the number of picks from seisbench
assert len(mspass_picks) < len(seis_picks)


def test_annotate_arrival_time_window():
"""
Test the annotate_arrival_time function with a time window.
Expand All @@ -76,7 +79,9 @@ def test_annotate_arrival_time_window():
timeseries = Trace2TimeSeries(stream[0])
window_start = UTCDateTime(2009, 4, 6, 1, 30).timestamp
window_end = window_start + 1000
annotate_arrival_time(timeseries, threshold = 0, time_window=TimeWindow(window_start, window_end))
annotate_arrival_time(
timeseries, threshold=0, time_window=TimeWindow(window_start, window_end)
)
mspass_picks = timeseries["p_wave_picks"]

assert len(mspass_picks.keys()) > 0
Expand All @@ -97,6 +102,7 @@ def test_annotate_arrival_time_window():
# Compare the sets
assert mspass_set == seis_set


def test_annotate_arrival_time_for_mseed():
"""
Test the annotate_arrival_time function for a mseed file.
Expand All @@ -108,8 +114,10 @@ def test_annotate_arrival_time_for_mseed():
timeseries = Trace2TimeSeries(trace)

window_start = UTCDateTime(2011, 3, 11, 6, 35).timestamp
window_end = window_start + 1200 # 20 minutes
annotate_arrival_time(timeseries, 0.1, time_window=TimeWindow(window_start, window_end))
window_end = window_start + 1200 # 20 minutes
annotate_arrival_time(
timeseries, 0.1, time_window=TimeWindow(window_start, window_end)
)
mspass_picks = timeseries["p_wave_picks"]

# assert the picks are not empty
Expand All @@ -130,14 +138,15 @@ def test_annotate_arrival_time_for_mseed():

# Convert both to sets of rounded values
mspass_set = set(round(v, 6) for v in mspass_picks.keys())
seis_set = set(round(v, 6) for v in seis_picks)
seis_set = set(round(v, 6) for v in seis_picks)

# every pick in mspass should be in seisbench
assert mspass_set.issubset(seis_set)

# assert that the number of picks from mspass is less than the number of picks from seisbench
assert len(mspass_picks) < len(seis_picks)


def get_mseed_trace_for_test():
file_path = os.path.join(os.getcwd(), "python/tests/data/db_mseeds/test_277.mseed")
st = read(file_path)
Expand All @@ -158,5 +167,6 @@ def get_trace_for_test():
endtime=t + 3600,
)


if __name__ == "__main__":
test_annotate_arrival_time_window()
56 changes: 30 additions & 26 deletions python/tests/algorithms/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,14 +650,15 @@ def test_windowdata_exceptions():
assert d.data[i, j] == 0.0
assert d.elog.size() == 1


def test_WindowData_autopad():
"""
Tests function added September 2024 that uses a soft test
for the data range. Returns a fixed length signal that
is padded unless the actual time span is too short as
defined by the fractional_mismatch_limit argument.
Tests function added September 2024 that uses a soft test
for the data range. Returns a fixed length signal that
is padded unless the actual time span is too short as
defined by the fractional_mismatch_limit argument.
The function is a higher level function using WindowData
so tests are limited to the features it adds and
so tests are limited to the features it adds and
additional exceptions it handles.
"""
# this duplicates code in test_WindowData
Expand All @@ -677,58 +678,61 @@ def test_WindowData_autopad():
# make copies because WindowData can alter content
ts0 = TimeSeries(ts)
se0 = Seismogram(se)
# first verify the function works correctly if the

# first verify the function works correctly if the
# stime:etime range is entirely within the data
ts = TimeSeries(ts0)
ts = WindowData_autopad(ts,ts.time(10),ts.time(50))
ts = WindowData_autopad(ts, ts.time(10), ts.time(50))
assert ts.live
assert ts.npts == 41

se = Seismogram(se0)
se = WindowData_autopad(se,se.time(10),se.time(50))
se = WindowData_autopad(se, se.time(10), se.time(50))
assert se.live
assert se.npts == 41
# verify automatic zero padding of first and last sample

# verify automatic zero padding of first and last sample
# when stime and etime are specified as that
ts = TimeSeries(ts0)
# the time method works with negative and resolves her
# the time method works with negative and resolves her
# to one sample before start and one sample after end time
ts = WindowData_autopad(ts,ts.time(-1),ts.time(npts))
ts = WindowData_autopad(ts, ts.time(-1), ts.time(npts))
assert ts.live
assert ts.npts == npts + 2
# isclose is not needed here as this is a hard set 0
assert ts.data[0] == 0.0
assert ts.data[npts+1] == 0.0
assert ts.data[npts + 1] == 0.0
# similar for Seismogram
se = Seismogram(se0)
# the time method works with negative and resolves her
# the time method works with negative and resolves her
# to one sample before start and one sample after end time
se = WindowData_autopad(se,se.time(-1),se.time(npts))
se = WindowData_autopad(se, se.time(-1), se.time(npts))
assert se.live
assert se.npts == npts + 2
for k in range(3):
assert se.data[k,0] == 0.0
assert se.data[k,npts+1] == 0.0
# test error handling for this function - not made a separate function
# as it isn't that complex
assert se.data[k, 0] == 0.0
assert se.data[k, npts + 1] == 0.0

# test error handling for this function - not made a separate function
# as it isn't that complex
d_foo = TimeSeriesEnsemble(2)
d_foo.member.append(ts)
d_foo.set_live()
with pytest.raises(TypeError,match="arg0 must be either a TimeSeries or Seismogram object"):
with pytest.raises(
TypeError, match="arg0 must be either a TimeSeries or Seismogram object"
):
# need to send TimeSeriesEnsemble to bypass type check of function decorator
ts = WindowData_autopad(d_foo,se.time(10),se.time(50))
ts = WindowData_autopad(d_foo, se.time(10), se.time(50))

# only test this for TimeSeries - no reason to think it would behave differently for Seismogram
ts = TimeSeries(ts0)
ts = WindowData_autopad(ts,ts.t0,ts.time(3*npts))
ts = WindowData_autopad(ts, ts.t0, ts.time(3 * npts))
assert ts.dead()
# this is actually 2 because it passes through WindowDataAtomic that also posts a log message
# Using >= to make the test more robust in the event that changes
assert ts.elog.size() >= 1


def test_TopMute():
ts = TimeSeries(100)
seis = Seismogram(100)
Expand Down
Loading

0 comments on commit 90c58d4

Please sign in to comment.