Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/slice intersect multi series #2591

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Darts Contribution Guidelines

## Picking an Issue on Which to Work
Picking an Issue on Which to Work

The backlog of issues and ongoing work is tracked here: https://github.com/unit8co/darts/projects/1
Anyone is welcome to pick an issue from the backlog, work on it and submit a pull request.
Expand Down
161 changes: 53 additions & 108 deletions darts/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from scipy.stats import kurtosis, skew

from darts import TimeSeries, concatenate
from darts.timeseries import intersect
from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries
from darts.utils.utils import expand_arr, freqs, generate_index
from darts.utils.utils import freqs, generate_index


class TestTimeSeries:
Expand Down Expand Up @@ -603,40 +604,64 @@ def check_intersect(other, start_, end_, freq_):
s_int_idx = series.slice_intersect_times(other, copy=False)
assert s_int.time_index.equals(s_int_idx)

def check_intersect_sequence(series, other, start_, end_, freq_):
    """Check ``intersect([series, other])`` against the pairwise ``slice_intersect``.

    ``start_``, ``end_`` and ``freq_`` describe the expected overlap; passing
    ``start_=None`` means that no overlap is expected, in which case only the
    emptiness of both results is verified.
    """
    pair = intersect([series, other])
    s_int, o_int = pair[0], pair[1]

    # the sequence helper must agree with the two-series method in both directions
    pairwise = [
        series.slice_intersect(other),
        other.slice_intersect(series),
    ]
    assert pair == pairwise

    if start_ is None:
        # empty slice: nothing but the lengths can be compared
        assert len(s_int) == 0
        assert len(o_int) == 0
    else:
        assert s_int.start_time() == o_int.start_time() == start_
        assert s_int.end_time() == o_int.end_time() == end_
        assert s_int.freq == o_int.freq == freq_

# slice with exact range
startA = start
endA = end
idxA = generate_index(startA, endA, freq=freq_other)
seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA))
check_intersect(seriesA, startA, endA, freq_expected)
check_intersect_sequence(series, seriesA, start, end, freq_expected)

# entire slice within the range
startA = start + freq
endA = startA + 6 * freq_other
idxA = generate_index(startA, endA, freq=freq_other)
seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA))
check_intersect(seriesA, startA, endA, freq_expected)
check_intersect_sequence(series, seriesA, startA, endA, freq_expected)

# start outside of range
startC = start - 4 * freq
endC = start + 4 * freq_other
idxC = generate_index(startC, endC, freq=freq_other)
seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC))
check_intersect(seriesC, start, endC, freq_expected)
check_intersect_sequence(series, seriesC, start, endC, freq_expected)

# end outside of range
startC = start + 4 * freq
endC = end + 4 * freq_other
idxC = generate_index(startC, endC, freq=freq_other)
seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC))
check_intersect(seriesC, startC, end, freq_expected)
check_intersect_sequence(series, seriesC, startC, end, freq_expected)

# small intersect
startE = start + (n_steps - 1) * freq
endE = startE + 2 * freq_other
idxE = generate_index(startE, endE, freq=freq_other)
seriesE = TimeSeries.from_series(pd.Series(range(len(idxE)), index=idxE))
check_intersect(seriesE, startE, end, freq_expected)
check_intersect_sequence(series, seriesE, startE, end, freq_expected)

# No intersect
startG = end + 3 * freq
Expand All @@ -645,6 +670,10 @@ def check_intersect(other, start_, end_, freq_):
seriesG = TimeSeries.from_series(pd.Series(range(len(idxG)), index=idxG))
# for empty slices, we expect the original freq
check_intersect(seriesG, None, None, freq)
check_intersect_sequence(series, seriesG, None, None, freq)

# Empty sequence
assert intersect([]) == []

@staticmethod
def helper_test_shift(test_case, test_series: TimeSeries):
Expand Down Expand Up @@ -762,9 +791,6 @@ def helper_test_prepend_values(test_case, test_series: TimeSeries):
assert test_series.time_index.equals(prepended_sq.time_index)
assert prepended_sq.components.equals(test_series.components)

# component and sample dimension should match
assert prepended._xa.shape[1:] == test_series._xa.shape[1:]

def test_slice(self):
TestTimeSeries.helper_test_slice(self, self.series1)

Expand Down Expand Up @@ -800,112 +826,18 @@ def test_append(self):
assert appended.time_index.equals(expected_idx)
assert appended.components.equals(series_1.components)

@pytest.mark.parametrize(
"config",
itertools.product(
[
( # univariate array
np.array([0, 1, 2]).reshape((3, 1, 1)),
np.array([0, 1]).reshape((2, 1, 1)),
),
( # multivariate array
np.array([0, 1, 2, 3, 4, 5]).reshape((3, 2, 1)),
np.array([0, 1, 2, 3]).reshape((2, 2, 1)),
),
( # empty array
np.array([0, 1, 2]).reshape((3, 1, 1)),
np.array([]).reshape((0, 1, 1)),
),
(
# wrong number of components
np.array([0, 1, 2]).reshape((3, 1, 1)),
np.array([0, 1, 2, 3]).reshape((2, 2, 1)),
),
(
# wrong number of samples
np.array([0, 1, 2]).reshape((3, 1, 1)),
np.array([0, 1, 2, 3]).reshape((2, 1, 2)),
),
( # univariate list with times
np.array([0, 1, 2]).reshape((3, 1, 1)),
[0, 1],
),
( # univariate list with times and components
np.array([0, 1, 2]).reshape((3, 1, 1)),
[[0], [1]],
),
( # univariate list with times, components and samples
np.array([0, 1, 2]).reshape((3, 1, 1)),
[[[0]], [[1]]],
),
( # multivar with list has wrong shape
np.array([0, 1, 2, 3]).reshape((2, 2, 1)),
[[1, 2], [3, 4]],
),
( # list with wrong number of components
np.array([0, 1, 2]).reshape((3, 1, 1)),
[[1, 2], [3, 4]],
),
( # list with wrong number of samples
np.array([0, 1, 2]).reshape((3, 1, 1)),
[[[0, 1]], [[1, 2]]],
),
( # multivar input but list has wrong shape
np.array([0, 1, 2, 3]).reshape((2, 2, 1)),
[1, 2],
),
],
[True, False],
["append_values", "prepend_values"],
),
)
def test_append_and_prepend_values(self, config):
(series_vals, vals), is_datetime, method = config
start = "20240101" if is_datetime else 1
series_idx = generate_index(
start=start, length=len(series_vals), name="some_name"
)
series = TimeSeries.from_times_and_values(
times=series_idx,
values=series_vals,
def test_append_values(self):
TestTimeSeries.helper_test_append_values(self, self.series1)
# Check `append_values` deals with `RangeIndex` series correctly:
series = linear_timeseries(start=1, length=5, freq=2)
appended = series.append_values(np.ones((2, 1, 1)))
expected_vals = np.concatenate(
[series.all_values(), np.ones((2, 1, 1))], axis=0
)

# expand if it's a list
vals_arr = np.array(vals) if isinstance(vals, list) else vals
vals_arr = expand_arr(vals_arr, ndim=3)

ts_method = getattr(TimeSeries, method)

if vals_arr.shape[1:] != series_vals.shape[1:]:
with pytest.raises(ValueError) as exc:
_ = ts_method(series, vals)
assert str(exc.value).startswith(
"The (expanded) values must have the same number of components and samples"
)
return

appended = ts_method(series, vals)

if method == "append_values":
expected_vals = np.concatenate([series_vals, vals_arr], axis=0)
expected_idx = generate_index(
start=series.start_time(),
length=len(series_vals) + len(vals),
freq=series.freq,
)
else:
expected_vals = np.concatenate([vals_arr, series_vals], axis=0)
expected_idx = generate_index(
end=series.end_time(),
length=len(series_vals) + len(vals),
freq=series.freq,
)

expected_idx = pd.RangeIndex(start=1, stop=15, step=2)
assert np.allclose(appended.all_values(), expected_vals)
assert appended.time_index.equals(expected_idx)
assert appended.components.equals(series.components)
assert appended._xa.shape[1:] == series._xa.shape[1:]
assert appended.time_index.name == series.time_index.name

def test_prepend(self):
TestTimeSeries.helper_test_prepend(self, self.series1)
Expand All @@ -921,6 +853,19 @@ def test_prepend(self):
assert prepended.time_index.equals(expected_idx)
assert prepended.components.equals(series_1.components)

def test_prepend_values(self):
    """Exercise ``prepend_values`` on ``self.series1`` and on an integer-indexed series."""
    TestTimeSeries.helper_test_prepend_values(self, self.series1)

    # Check `prepend_values` deals with `RangeIndex` series correctly:
    pad_shape = (2, 1, 1)
    base = linear_timeseries(start=1, length=5, freq=2)
    result = base.prepend_values(np.ones(pad_shape))

    # two extra steps of size 2 in front of an index starting at 1 -> starts at -3
    want_idx = pd.RangeIndex(start=-3, stop=11, step=2)
    want_vals = np.concatenate([np.ones(pad_shape), base.all_values()], axis=0)

    assert np.allclose(result.all_values(), want_vals)
    assert result.time_index.equals(want_idx)
    assert result.components.equals(base.components)

@pytest.mark.parametrize(
"config",
[
Expand Down Expand Up @@ -2432,8 +2377,8 @@ def test_time_col_with_tz(self):
assert list(ts.time_index.tz_localize("CET")) == list(time_range_H)
assert ts.time_index.tz is None

series = pd.Series(data=values, index=time_range_H)
ts = TimeSeries.from_series(pd_series=series)
serie = pd.Series(data=values, index=time_range_H)
ts = TimeSeries.from_series(pd_series=serie)
assert list(ts.time_index) == list(time_range_H.tz_localize(None))
assert list(ts.time_index.tz_localize("CET")) == list(time_range_H)
assert ts.time_index.tz is None
Expand Down
33 changes: 33 additions & 0 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5659,6 +5659,39 @@ def concatenate(
return TimeSeries.from_xarray(da_concat, fill_missing_dates=False)


def intersect(series: Sequence[TimeSeries]):
    """Intersect several ``TimeSeries`` along their time index.

    Every input is checked against the first one for its kind of time index
    (``has_datetime_index``); the underlying data arrays are then aligned with
    ``xr.align`` while leaving the ``component`` and ``sample`` dimensions
    untouched.

    Parameters
    ----------
    series : Sequence[TimeSeries]
        The ``TimeSeries`` to intersect.

    Returns
    -------
    Sequence[TimeSeries]
        The intersected series, one per input and in the same order. An empty
        input yields an empty list.

    Raises
    ------
    IndexError
        Handed to ``raise_log`` when the series do not all share the same
        ``has_datetime_index`` value.
    """
    if not series:
        return []

    expects_datetime = series[0].has_datetime_index

    def _validated_array(ts):
        # Validate the index type, then expose the data array without copying.
        if ts.has_datetime_index != expects_datetime:
            raise_log(
                IndexError(
                    "The time index type must be the same for all TimeSeries in the Sequence."
                ),
                logger,
            )
        return ts.data_array(copy=False)

    aligned = xr.align(
        *[_validated_array(ts) for ts in series],
        exclude=["component", "sample"],
    )
    return [TimeSeries.from_xarray(da) for da in aligned]


def _finite_rows_boundaries(
values: np.ndarray, how: str = "all"
) -> tuple[Optional[int], Optional[int]]:
Expand Down
Loading