Skip to content

Commit

Permalink
fixes for new holidays version
Browse files (browse the repository at this point in the history)
  • Loading branch information
Nate Parsons committed Dec 6, 2023
1 parent 1498c9a commit cb9acba
Show file tree
Hide file tree
Showing 8 changed files with 11 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@ class DateToHoliday(TransformPrimitive):
>>> date_to_holiday_canada = DateToHoliday(country='Canada')
>>> dates = pd.Series([datetime(2016, 7, 1),
... datetime(2016, 11, 15),
... datetime(2017, 12, 26),
... datetime(2018, 9, 3)])
>>> date_to_holiday_canada(dates).tolist()
['Canada Day', nan, 'Boxing Day', 'Labour Day']
['Canada Day', nan, 'Labour Day']
"""

name = "date_to_holiday"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ class DistanceToHoliday(TransformPrimitive):
We can also control the country in which we're searching for
a holiday.
>>> distance_to_holiday = DistanceToHoliday("Victoria Day", country='Canada')
>>> distance_to_holiday = DistanceToHoliday("Canada Day", country='Canada')
>>> dates = [datetime(2010, 1, 1),
... datetime(2012, 5, 31),
... datetime(2017, 7, 31),
... datetime(2020, 12, 31)]
>>> distance_to_holiday(dates).tolist()
[143, -10, -70, 144]
[181, 31, -30, 182]
"""

name = "distance_to_holiday"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def test_deserializer_uses_common_primitive_instances_with_args(es, tmp_path):
# Test primitive with multiple args - pandas only due to primitive compatibility
if es.dataframe_type == Library.PANDAS:
distance_to_holiday = DistanceToHoliday(
holiday="Victoria Day",
holiday="Canada Day",
country="Canada",
)
features = dfs(
Expand Down Expand Up @@ -491,7 +491,7 @@ def test_deserializer_uses_common_primitive_instances_with_args(es, tmp_path):
assert all(
[f.primitive is new_distance_primitive for f in new_distance_features],
)
assert new_distance_primitive.holiday == "Victoria Day"
assert new_distance_primitive.holiday == "Canada Day"
assert new_distance_primitive.country == "Canada"

# Test primitive with list arg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ def test_valid_country():
[
"2016-07-01",
"2016-11-11",
"2017-12-26",
"2018-09-03",
],
).astype("datetime64[ns]")
answer = ["Canada Day", np.nan, "Boxing Day", "Labour Day"]
answer = ["Canada Day", np.nan, "Labour Day"]
given_answer = date_to_holiday(case).astype("str")
np.testing.assert_array_equal(given_answer, answer)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from datetime import datetime

import holidays
import numpy as np
import pandas as pd
import pytest
from packaging.version import parse

from featuretools.primitives import DistanceToHoliday

Expand All @@ -25,35 +23,6 @@ def test_distanceholiday():
np.testing.assert_array_equal(output, expected)


def test_holiday_out_of_range():
    """Check DistanceToHoliday("Boxing Day", country="Canada") on four dates.

    The expected distances are chosen from the installed ``holidays``
    version: the second value switches at 0.15.0, and the first and last
    values switch from NaN to integers at 0.17.0.
    """
    distance_to_boxing_day = DistanceToHoliday("Boxing Day", country="Canada")

    input_dates = pd.Series(
        [
            datetime(2010, 1, 1),
            datetime(2012, 5, 31),
            datetime(2017, 7, 31),
            datetime(2020, 12, 31),
        ],
    )

    # Parse the installed version once and derive both thresholds from it.
    installed_version = parse(holidays.__version__)
    at_least_0_15 = installed_version >= parse("0.15.0")
    at_least_0_17 = installed_version >= parse("0.17.0")

    if at_least_0_15:
        mid_2012_expected = -157
    else:
        mid_2012_expected = 209

    if at_least_0_17:
        first_day_expected = -6
        last_day_expected = -5
    else:
        first_day_expected = np.nan
        last_day_expected = np.nan

    expected = pd.Series(
        [
            first_day_expected,
            mid_2012_expected,
            148,
            last_day_expected,
        ],
    )
    pd.testing.assert_series_equal(
        distance_to_boxing_day(input_dates),
        expected,
        check_names=False,
    )


def test_unknown_country_error():
error_text = r"must be one of the available countries.*"
with pytest.raises(ValueError, match=error_text):
Expand Down Expand Up @@ -82,7 +51,7 @@ def test_nat():


def test_valid_country():
distance_to_holiday = DistanceToHoliday("Victoria Day", country="Canada")
distance_to_holiday = DistanceToHoliday("Canada Day", country="Canada")
case = pd.Series(
[
"2010-01-01",
Expand All @@ -91,7 +60,7 @@ def test_valid_country():
"2020-12-31",
],
).astype("datetime64[ns]")
answer = [143, -10, -70, 144]
answer = [181, 31, -30, 182]
given_answer = distance_to_holiday(case).astype("float")
np.testing.assert_array_equal(given_answer, answer)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@ def test_valid_country():
[
"2016-07-01",
"2016-11-11",
"2017-12-26",
"2018-09-03",
],
).astype("datetime64[ns]")
answer = pd.Series([True, False, True, True])
answer = pd.Series([True, False, True])
given_answer = pd.Series(primitive_func(case))
assert given_answer.equals(answer)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cloudpickle==1.5.0
holidays==0.13
holidays==0.17
numpy==1.21.0
packaging==20.0
pandas==1.5.0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ license = {text = "BSD 3-clause"}
requires-python = ">=3.8,<4"
dependencies = [
"cloudpickle >= 1.5.0",
"holidays >= 0.13, < 0.33",
"holidays >= 0.17",
"numpy >= 1.21.0",
"packaging >= 20.0",
"pandas >= 1.5.0",
Expand Down

0 comments on commit cb9acba

Please sign in to comment.