diff --git a/traffic_prophet/config.py b/traffic_prophet/config.py index 4d84798..0362d70 100644 --- a/traffic_prophet/config.py +++ b/traffic_prophet/config.py @@ -4,7 +4,7 @@ # For countmatch cm = { - 'verbose': True, + 'verbose': False, 'min_count': 96, 'min_counts_in_day': 24, 'min_permanent_months': 12, diff --git a/traffic_prophet/countmatch/derivedvals.py b/traffic_prophet/countmatch/derivedvals.py index 5eef61d..c4fdd37 100644 --- a/traffic_prophet/countmatch/derivedvals.py +++ b/traffic_prophet/countmatch/derivedvals.py @@ -19,7 +19,7 @@ class DVRegistrar(type): def __init__(cls, name, bases, dct): - # Register GrowthFactorBase subclass if `_dv_type` not already taken. + # Register DerivedVal subclass if `_dv_type` not already taken. if name not in ('DerivedValsBase', 'DerivedVals'): if not hasattr(cls, "_dv_type"): raise ValueError("must define a `_dv_type`.") @@ -57,6 +57,19 @@ class DerivedValsBase(metaclass=DVRegistrar): @staticmethod def preprocess_daily_counts(dc): + """Preprocess daily counts by adding month and day of week. + + Parameters + --------------- + dc : pandas.DataFrame + Daily counts. + + Returns + ------- + dca : pandas.DataFrame + Copy of `dc` with 'Month' and 'Day of Week' columns. + + """ # Ensure dc remains unchanged by the method. dca = dc.copy() dca['Month'] = dca['Date'].dt.month @@ -64,6 +77,20 @@ def preprocess_daily_counts(dc): return dca def get_madt(self, dca): + """Get monthly average daily traffic (MADT) from processed counts. + + Parameters + --------------- + dca : pandas.DataFrame + Daily counts, with 'Month' and 'Day of Week' columns from + `preprocess_daily_counts`. + + Returns + ------- + madt : pandas.DataFrame + MADT table, with 'Days Available' and 'Days in Month' columns. + + """ dc_m = dca.groupby(['Year', 'Month']) madt = pd.DataFrame({ 'MADT': dc_m['Daily Count'].mean(), @@ -73,9 +100,10 @@ def get_madt(self, dca): names=['Year', 'Month']) ) - # Loop to record number of days in month. 
+ # Loop to record the number of days in each month. February requires + # recalculating in case of leap years. days_in_month = [] - for year in dca.index.levels[0]: + for year in madt.index.levels[0]: cdays = self._months_of_year.copy() cdays[1] = pd.to_datetime('{0:d}-02-01'.format(year)).daysinmonth days_in_month += cdays @@ -85,7 +113,22 @@ def get_madt(self, dca): @staticmethod def get_aadt_py_from_madt(madt, perm_years): - # Weighted average for AADT. + """Annual average daily traffic (AADT) from an MADT weighted average. + + Parameters + --------------- + madt : pandas.DataFrame + MADT, with 'Days in Month' column as from `get_madt`. + perm_years : list + List of permanent count years for location; obtained from + PermCount.perm_years. + + Returns + ------- + aadt : pandas.DataFrame + AADT table. + + """ madt_py = madt.loc[perm_years, :] monthly_total_traffic = madt_py['MADT'] * madt_py['Days in Month'] return pd.DataFrame( @@ -93,9 +136,33 @@ def get_aadt_py_from_madt(madt, perm_years): madt_py.groupby('Year')['Days in Month'].sum())}) @staticmethod - def get_ratios_py(dca, madt, aadt, perm_years): - dca_py = dca.loc[perm_years].copy() - dc_dom = dca_py.groupby(['Year', 'Month', 'Day of Week']) + def get_ratios_py(dca, madt, aadt_py, perm_years): + """Ratios between MADT and AADT and day-of-month average daily traffic. + + Parameters + --------------- + dca : pandas.DataFrame + Daily counts, with 'Month' and 'Day of Week' columns from + `preprocess_daily_counts`. + madt : pandas.DataFrame + MADT, with 'Days in Month' column as from `get_madt`. + aadt_py : pandas.DataFrame + AADT for permanent years, as from `get_aadt_py_from_madt`. + perm_years : list + List of permanent count years for location; obtained from + PermCount.perm_years. + + Returns + ------- + dom_ijd : pandas.DataFrame + Ratio between MADT and day-of-month ADT. + d_ijd : pandas.DataFrame + Ratio between AADT and day-of-month ADT. 
+ n_avail_days : pandas.DataFrame + Number of days used to calculate day-of-month ADT. + + """ + dc_dom = dca.loc[perm_years].groupby(['Year', 'Month', 'Day of Week']) # Multi-index levels retain all values from dca even after using loc, # so can only use `perm_years`. ymd_index = pd.MultiIndex.from_product( @@ -112,8 +179,9 @@ def get_ratios_py(dca, madt, aadt, perm_years): # broadcasting trick.) dom_ijd = (madt['MADT'].loc[perm_years, :].values[:, np.newaxis] / domadt) - # Determine day-to-year conversion factor D_ijd. - d_ijd = (aadt['AADT'].values[:, np.newaxis] / + # Determine day-to-year conversion factor D_ijd. (Uses broadcasting + # and pivoting pandas columns.) + d_ijd = (aadt_py['AADT'].values[:, np.newaxis] / domadt.unstack(level=-1)).stack() return dom_ijd, d_ijd, n_avail_days @@ -130,6 +198,16 @@ def __init__(self, impute_ratios=False): self._impute_ratios = impute_ratios def get_derived_vals(self, ptc): + """Get derived values, including ADTs and ratios between them. + + Depending on settings, will also impute missing values. + + Parameters + ---------- + ptc : permcount.PermCount + Permanent count instance. + + """ dca = self.preprocess_daily_counts(ptc.data) madt = self.get_madt(dca) aadt = self.get_aadt_py_from_madt(madt, ptc.perm_years) @@ -139,3 +217,9 @@ def get_derived_vals(self, ptc): ptc.adts = {'MADT': madt, 'AADT': aadt} ptc.ratios = {'DoM_ijd': dom_ijd, 'D_ijd': d_ijd, 'N_avail_days': n_avail_days} + + if self._impute_ratios: + self.impute_ratios(ptc) + + def impute_ratios(self, ptc): + raise NotImplementedError diff --git a/traffic_prophet/countmatch/growthfactor.py b/traffic_prophet/countmatch/growthfactor.py index d7d966b..3a60c6c 100644 --- a/traffic_prophet/countmatch/growthfactor.py +++ b/traffic_prophet/countmatch/growthfactor.py @@ -46,14 +46,39 @@ class GrowthFactorBase(metaclass=GFRegistrar): @staticmethod def get_aadt(tc): - """Get AADT (for permanent years).""" + """Get AADT (for permanent years). 
+ + Parameters + ---------- + tc : permcount.PermCount + Permanent count instance. + + Returns + ------- + aadt : pandas.DataFrame + Modified AADT table where 'Year' is a column, not an index. + + """ aadt = tc.adts['AADT'].reset_index() aadt['Year'] = aadt['Year'].astype(float) return aadt @staticmethod def get_wadt_py(tc): - """Get WADT for all full weeks within permanent years.""" + """Get WADT for all full weeks within permanent years. + + Parameters + ---------- + tc : permcount.PermCount + Permanent count instance. + + Returns + ------- + wadt : pandas.DataFrame + Weekly average daily traffic, and 'Time' in weeks from the start of + the first permanent count year (in `perm_years`). + + """ cdata = tc.data.loc[tc.perm_years, :].copy() # Corrective measure, as dt.week returns the "week ordinal". See @@ -122,6 +147,14 @@ def exponential_rate_fit(year, aadt, ref_vals): return model.fit() def fit_growth(self, tc): + """Fit an exponential growth model to AADT vs. year. + + Parameters + ---------- + tc : permcount.PermCount + Permanent count instance. + + """ # Process year vs. AADT data. aadt = self.get_aadt(tc) @@ -174,7 +207,14 @@ def linear_rate_fit(week, wadt): return model.fit() def fit_growth(self, tc): - # Process week vs. weekly averaged ADT. + """Fit a linear growth model to WADT vs. time in weeks. + + Parameters + ---------- + tc : permcount.PermCount + Permanent count instance. + + """ wadt_py = self.get_wadt_py(tc) fit_results = self.linear_rate_fit(wadt_py['Time'].values, wadt_py['WADT'].values) @@ -183,7 +223,7 @@ def fit_growth(self, tc): # AADTs can only be calculated for perm_years, so consistent with # `get_wadt_py`. growth_factor = 1. + (fit_results.params[1] * 52. 
/ - tc.adts['AADT']['AADT'].values[0]) + tc.adts['AADT'].iat[0, 0]) tc._growth_fit = {'fit_type': 'Linear', 'fit_results': fit_results, @@ -201,6 +241,14 @@ def __init__(self): self.wadt_lin = GrowthFactorWADTLin() def fit_growth(self, tc): + """Fit a linear model for single years, exponential for multiple years. + + Parameters + ---------- + tc : permcount.PermCount + Permanent count instance. + + """ if tc.adts['AADT'].shape[0] > 1: self.aadt_exp.fit_growth(tc) else: diff --git a/traffic_prophet/countmatch/permcount.py b/traffic_prophet/countmatch/permcount.py index 78df42c..458f880 100644 --- a/traffic_prophet/countmatch/permcount.py +++ b/traffic_prophet/countmatch/permcount.py @@ -37,10 +37,12 @@ def __init__(self, count_id, centreline_id, direction, data, @property def perm_years(self): + """List of years where the count is permanent.""" return self._perm_years @property def growth_factor(self): + """Multiplicative year-on-year growth factor.""" if self._growth_fit is None: raise AttributeError("PTC has not had its growth factor fit!") return self._growth_fit['growth_factor'] @@ -63,7 +65,7 @@ class PermCountProcessor: dv_calc : DerivedVal instance For imputation and derived properties. gf_calc : GrowthFactor instance - For estimating growth factor. + For estimating the growth factor. cfg : dict, optional Configuration settings. Default: `config.cm`. """ @@ -72,8 +74,7 @@ def __init__(self, dv_calc, gf_calc, cfg=cfg.cm): self.dvc = dv_calc self.gfc = gf_calc self.cfg = cfg - # Obtain a list of count_ids that (according to TEPs-I) should not be - # PTCs because they reduce the accuracy of CountMatch. + # Obtain a list of count_ids that should be excluded from being PTCs. 
self.excluded_ids = (self.cfg['exclude_ptc_pos'] + [-id for id in self.cfg['exclude_ptc_neg']]) self._disable_tqdm = not self.cfg['verbose'] diff --git a/traffic_prophet/countmatch/tests/conftest.py b/traffic_prophet/countmatch/tests/conftest.py index b080d84..745f928 100644 --- a/traffic_prophet/countmatch/tests/conftest.py +++ b/traffic_prophet/countmatch/tests/conftest.py @@ -1,6 +1,6 @@ """CountMatch test suite preprocessing. -Fixtures used by multiple files in this and subdirectories must be placed here. +Fixtures used by multiple files in this and subdirectories are placed here. See https://docs.pytest.org/en/latest/fixture.html#conftest-py-sharing-fixture-functions for more. diff --git a/traffic_prophet/countmatch/tests/test_base.py b/traffic_prophet/countmatch/tests/test_base.py index 08c0b04..7eb44bd 100644 --- a/traffic_prophet/countmatch/tests/test_base.py +++ b/traffic_prophet/countmatch/tests/test_base.py @@ -1,3 +1,5 @@ +import operator + from .. import base @@ -7,6 +9,6 @@ def test_count(self): count = base.Count('test', 1, -1., None) assert count.count_id == 'test' assert count.centreline_id == 1 - assert count.direction == -1 + assert operator.index(count.direction) == -1 assert count.data is None assert not count.is_permanent diff --git a/traffic_prophet/countmatch/tests/test_derivedvals.py b/traffic_prophet/countmatch/tests/test_derivedvals.py index 1abd618..413c160 100644 --- a/traffic_prophet/countmatch/tests/test_derivedvals.py +++ b/traffic_prophet/countmatch/tests/test_derivedvals.py @@ -1,5 +1,4 @@ import pytest -import hypothesis as hyp import numpy as np import pandas as pd @@ -7,111 +6,135 @@ from .. 
import derivedvals as dv -def get_single_ptc(sample_counts, cfgcm_test, count_id): - pcpp = pc.PermCountProcessor(None, None, cfg=cfgcm_test) - perm_years = pcpp.partition_years(sample_counts.counts[count_id]) - ptc = pc.PermCount.from_count_object(sample_counts.counts[count_id], - perm_years) +def get_single_ptc(counts, cfgcm, count_id): + pcpp = pc.PermCountProcessor(None, None, cfg=cfgcm) + perm_years = pcpp.partition_years(counts.counts[count_id]) + ptc = pc.PermCount.from_count_object(counts.counts[count_id], perm_years) return ptc -@pytest.fixture(scope='module') -def ptc_oneyear(sample_counts, cfgcm_test): - return get_single_ptc(sample_counts, cfgcm_test, -890) +class TestDVRegistrarDerivedVals: + """Tests DVRegistrar and DerivedVals.""" + def test_dvregistrar(self): -@pytest.fixture(scope='module') -def ptc_multiyear(sample_counts, cfgcm_test): - return get_single_ptc(sample_counts, cfgcm_test, -104870) + # Test successful initialization of DerivedVals subclass. + class DerivedValsStandardTest(dv.DerivedValsStandard): + _dv_type = 'Testing' + + assert dv.DV_REGISTRY['Testing'] is DerivedValsStandardTest + dv_instance = dv.DerivedVals('Testing') + assert dv_instance._dv_type == 'Testing' + + # Pop the dummy class, in case we test twice. + dv.DVRegistrar._registry.pop('Testing', None) + + # Test repeated `_dv_type` error handling. + with pytest.raises(ValueError) as excinfo: + class DerivedValsStandardBad1(dv.DerivedValsStandard): + pass + assert "already registered in" in str(excinfo.value) + + # Test missing `_dv_type` error handling. 
+ with pytest.raises(ValueError) as excinfo: + class DerivedValsStandardBad2(dv.DerivedValsBase): + pass + assert "must define a" in str(excinfo.value) class TestDerivedValsBase: + @pytest.fixture(params=[-890, -104870]) + def ptc_sample(self, sample_counts, cfgcm_test, request): + return get_single_ptc(sample_counts, cfgcm_test, request.param) + def setup(self): self.dvc = dv.DerivedValsBase() - def test_preprocess_daily_counts(self, ptc_multiyear): - dca = self.dvc.preprocess_daily_counts(ptc_multiyear.data) - assert 'Month' in dca.columns - assert 'Day of Week' in dca.columns - - def test_get_madt(self, ptc_oneyear, ptc_multiyear): - for ptc in (ptc_oneyear, ptc_multiyear): - dca = self.dvc.preprocess_daily_counts(ptc.data) - madt = self.dvc.get_madt(dca) - - madt_ref = pd.DataFrame({ - 'MADT': dca.groupby(['Year', 'Month'])['Daily Count'].mean(), - 'Days Available': dca.groupby( - ['Year', 'Month'])['Daily Count'].count()}, - index=pd.MultiIndex.from_product( - [dca.index.levels[0], np.arange(1, 13, dtype=int)], - names=['Year', 'Month'])) - madt_ref['Days in Month'] = [ - pd.to_datetime("{0}-{1}-01".format(*idxs)).daysinmonth - for idxs in madt_ref.index] - - assert np.allclose(madt['MADT'], madt_ref['MADT'], rtol=1e-10, - equal_nan=True) - assert np.allclose(madt['Days Available'], - madt_ref['Days Available'], rtol=1e-10, - equal_nan=True) - assert np.allclose(madt['Days in Month'], - madt_ref['Days in Month'], rtol=1e-10, - equal_nan=True) - - # None of the sample data are for leap years. 
- assert (madt['Days in Month'].sum() // - len(dca.index.levels[0])) == 365 - - def test_get_aadt_py_from_madt(self, ptc_oneyear, ptc_multiyear): - for ptc in (ptc_oneyear, ptc_multiyear): - madt = self.dvc.get_madt( - self.dvc.preprocess_daily_counts(ptc.data)) - aadt = self.dvc.get_aadt_py_from_madt(madt, ptc.perm_years) - - madt_py = madt.loc[ptc.perm_years, :].copy() - madt_py['Weighted MADT'] = (madt_py['MADT'] * - madt_py['Days in Month']) - madtg = madt_py.groupby('Year') - aadt_ref = (madtg['Weighted MADT'].sum() / - madtg['Days in Month'].sum()) - - assert np.allclose(aadt['AADT'], aadt_ref, rtol=1e-10) - - def test_get_ratios_py(self, ptc_oneyear, ptc_multiyear): - for ptc in (ptc_oneyear, ptc_multiyear): - dca = self.dvc.preprocess_daily_counts(ptc.data) - madt = self.dvc.get_madt(dca) - aadt = self.dvc.get_aadt_py_from_madt(madt, ptc.perm_years) - dom_ijd, d_ijd, n_avail_days = ( - self.dvc.get_ratios_py(dca, madt, aadt, ptc.perm_years)) - - dc_dom = (dca.loc[ptc.perm_years] - .groupby(['Year', 'Month', 'Day of Week'])) - domadt = (dc_dom['Daily Count'].mean() - .unstack(level=-1, fill_value=np.nan)) - n_avail_days_ref = (dc_dom['Daily Count'].count() - .unstack(level=-1, fill_value=np.nan)) - - assert np.allclose(n_avail_days, n_avail_days_ref, - rtol=1e-10, equal_nan=True) - - # Test if we can recover MADT from `domadt` and `dom_ijd` - madt_pym = np.repeat(madt.loc[ptc.perm_years, 'MADT'] - .values[:, np.newaxis], 7, axis=1) - madt_pym_est = (domadt * dom_ijd).values - assert np.allclose(madt_pym_est[~np.isnan(madt_pym_est)], - madt_pym[~np.isnan(madt_pym_est)], - rtol=1e-10) - - # Test if we can recover AADT from `domadt` and `d_ijd`. 
- aadt_pym = np.repeat(aadt['AADT'] - .values[:, np.newaxis], 7 * 12, axis=1) - aadt_pym_est = (domadt * d_ijd).unstack(level=-1).values - assert np.allclose(aadt_pym_est[~np.isnan(aadt_pym_est)], - aadt_pym[~np.isnan(aadt_pym_est)], - rtol=1e-10) + def test_preprocess_daily_counts(self, ptc_sample): + dca = self.dvc.preprocess_daily_counts(ptc_sample.data) + assert np.array_equal(dca['Month'], ptc_sample.data['Date'].dt.month) + assert np.array_equal(dca['Day of Week'], + ptc_sample.data['Date'].dt.dayofweek) + + def test_get_madt(self, ptc_sample): + dca = self.dvc.preprocess_daily_counts(ptc_sample.data) + madt = self.dvc.get_madt(dca) + + madt_ref = pd.DataFrame({ + 'MADT': dca.groupby(['Year', 'Month'])['Daily Count'].mean(), + 'Days Available': dca.groupby( + ['Year', 'Month'])['Daily Count'].count()}, + index=pd.MultiIndex.from_product( + [dca.index.levels[0], np.arange(1, 13, dtype=int)], + names=['Year', 'Month'])) + madt_ref['Days in Month'] = [ + pd.to_datetime("{0}-{1}-01".format(*idxs)).daysinmonth + for idxs in madt_ref.index] + + tols = {'rtol': 1e-10, 'equal_nan': True} + assert np.allclose(madt['MADT'], madt_ref['MADT'], **tols) + assert np.allclose(madt['Days Available'], + madt_ref['Days Available'], **tols) + assert np.allclose(madt['Days in Month'], + madt_ref['Days in Month'], **tols) + + # Every year has 365 days; the leap year 2012 (if present) adds one. 
+ assert (madt['Days in Month'].sum() // + len(dca.index.levels[0])) == 365 + assert (madt['Days in Month'].sum() % len(dca.index.levels[0])) == ( + 1 if 2012 in dca.index.levels[0] else 0) + + def test_get_aadt_py_from_madt(self, ptc_sample): + madt = self.dvc.get_madt( + self.dvc.preprocess_daily_counts(ptc_sample.data)) + aadt = self.dvc.get_aadt_py_from_madt(madt, ptc_sample.perm_years) + + madt_py = madt.loc[ptc_sample.perm_years, :].copy() + madt_py['Weighted MADT'] = (madt_py['MADT'] * + madt_py['Days in Month']) + madtg = madt_py.groupby('Year') + aadt_ref = (madtg['Weighted MADT'].sum() / + madtg['Days in Month'].sum()) + + assert np.allclose(aadt['AADT'], aadt_ref, rtol=1e-10) + + def test_get_ratios_py(self, ptc_sample): + dca = self.dvc.preprocess_daily_counts(ptc_sample.data) + madt = self.dvc.get_madt(dca) + aadt = self.dvc.get_aadt_py_from_madt(madt, ptc_sample.perm_years) + dom_ijd, d_ijd, n_avail_days = ( + self.dvc.get_ratios_py(dca, madt, aadt, ptc_sample.perm_years)) + + dc_dom = (dca.loc[ptc_sample.perm_years] + .groupby(['Year', 'Month', 'Day of Week'])) + domadt = (dc_dom['Daily Count'].mean() + .unstack(level=-1, fill_value=np.nan)) + n_avail_days_ref = (dc_dom['Daily Count'].count() + .unstack(level=-1, fill_value=np.nan)) + + assert np.allclose(n_avail_days, n_avail_days_ref, + rtol=1e-10, equal_nan=True) + + # Test if we can recover MADT from `domadt` and `dom_ijd`. + madt_pym = np.repeat(madt.loc[ptc_sample.perm_years, 'MADT'] + .values[:, np.newaxis], 7, axis=1) + madt_pym_est = (domadt * dom_ijd).values + # madt_pym naturally has no NaNs, while madt_pym_est does, so only + # compare non-NaN values. + assert np.allclose(madt_pym_est[~np.isnan(madt_pym_est)], + madt_pym[~np.isnan(madt_pym_est)], + rtol=1e-10) + + # Test if we can recover AADT from `domadt` and `d_ijd`. 
+ aadt_pym = np.repeat(aadt['AADT'] + .values[:, np.newaxis], 7 * 12, axis=1) + aadt_pym_est = (domadt * d_ijd).unstack(level=-1).values + # aadt_pym naturally has no NaNs, while aadt_pym_est does, so only + # compare non-NaN values. + assert np.allclose(aadt_pym_est[~np.isnan(aadt_pym_est)], + aadt_pym[~np.isnan(aadt_pym_est)], + rtol=1e-10) class TestDerivedValsStandard: @@ -119,17 +142,14 @@ class TestDerivedValsStandard: def setup(self): self.dvc = dv.DerivedValsStandard() - def test_get_derived_vals(self, sample_counts, cfgcm_test): - ptc_oneyear = get_single_ptc(sample_counts, cfgcm_test, -890) - ptc_multiyear = get_single_ptc(sample_counts, cfgcm_test, -104870) - - for ptc in (ptc_oneyear, ptc_multiyear): - self.dvc.get_derived_vals(ptc) - assert 'MADT' in ptc.adts.keys() - assert 'AADT' in ptc.adts.keys() - assert 'DoM_ijd' in ptc.ratios.keys() - assert 'D_ijd' in ptc.ratios.keys() - assert 'N_avail_days' in ptc.ratios.keys() + @pytest.mark.parametrize('count_id', [-890, -104870]) + def test_get_derived_vals(self, sample_counts, cfgcm_test, count_id): + ptc = get_single_ptc(sample_counts, cfgcm_test, count_id) + + self.dvc.get_derived_vals(ptc) + assert sorted(list(ptc.adts.keys())) == ['AADT', 'MADT'] + assert sorted(list(ptc.ratios.keys())) == [ + 'D_ijd', 'DoM_ijd', 'N_avail_days'] def test_imputer(self, sample_counts, cfgcm_test): pass diff --git a/traffic_prophet/countmatch/tests/test_growthfactor.py b/traffic_prophet/countmatch/tests/test_growthfactor.py index 6f3d2fc..2636acd 100644 --- a/traffic_prophet/countmatch/tests/test_growthfactor.py +++ b/traffic_prophet/countmatch/tests/test_growthfactor.py @@ -8,29 +8,56 @@ from .. 
import growthfactor as gf -def get_single_ptc(sample_counts, cfgcm_test, count_id): - pcpp = pc.PermCountProcessor(None, None, cfg=cfgcm_test) - perm_years = pcpp.partition_years(sample_counts.counts[count_id]) - ptc = pc.PermCount.from_count_object(sample_counts.counts[count_id], +def get_single_ptc(counts, cfgcm, count_id): + pcpp = pc.PermCountProcessor(None, None, cfg=cfgcm) + perm_years = pcpp.partition_years(counts.counts[count_id]) + ptc = pc.PermCount.from_count_object(counts.counts[count_id], perm_years) dvs = dv.DerivedVals('Standard') dvs.get_derived_vals(ptc) return ptc -@pytest.fixture(scope='module') -def ptc_oneyear(sample_counts, cfgcm_test): - return get_single_ptc(sample_counts, cfgcm_test, -890) +class TestGFRegistrarGrowthFactor: + """Tests GFRegistrar and GrowthFactor.""" + def test_gfregistrar(self): -@pytest.fixture(scope='module') -def ptc_multiyear(sample_counts, cfgcm_test): - return get_single_ptc(sample_counts, cfgcm_test, -104870) + # Test successful initialization of GrowthFactor subclass. + class GrowthFactorCompositeTest(gf.GrowthFactorComposite): + _fit_type = 'Testing' + + assert gf.GF_REGISTRY['Testing'] is GrowthFactorCompositeTest + gf_instance = gf.GrowthFactor('Testing') + assert gf_instance._fit_type == 'Testing' + + # Pop the dummy class, in case we test twice. + gf.GFRegistrar._registry.pop('Testing', None) + + # Test repeated `_fit_type` error handling. + with pytest.raises(ValueError) as excinfo: + class GrowthFactorCompositeBad1(gf.GrowthFactorComposite): + pass + assert "already registered in" in str(excinfo.value) + + # Test missing `_fit_type` error handling. 
+ with pytest.raises(ValueError) as excinfo: + class GrowthFactorCompositeBad2(gf.GrowthFactorBase): + pass + assert "must define a" in str(excinfo.value) class TestGrowthFactorBase: """Test growth factor base class.""" + @pytest.fixture() + def ptc_oneyear(self, sample_counts, cfgcm_test): + return get_single_ptc(sample_counts, cfgcm_test, -890) + + @pytest.fixture() + def ptc_multiyear(self, sample_counts, cfgcm_test): + return get_single_ptc(sample_counts, cfgcm_test, -104870) + def setup(self): self.gfb = gf.GrowthFactorBase() @@ -45,7 +72,7 @@ def test_get_aadt(self, ptc_oneyear, ptc_multiyear): aadt['AADT'].values) def test_get_wadt_py(self, ptc_oneyear, ptc_multiyear): - # For single year PTC, confirm WADT values for individual weeks. + # Confirm WADT values for individual weeks. wadt_oy = self.gfb.get_wadt_py(ptc_oneyear) wadt_jun14 = (ptc_oneyear.data .loc[(2010, 165):(2010, 171), 'Daily Count'].mean()) @@ -56,7 +83,6 @@ def test_get_wadt_py(self, ptc_oneyear, ptc_multiyear): assert np.isclose( wadt_oy.loc[wadt_oy['Week'] == 48, 'WADT'].values[0], wadt_nov29) - # For multiyear PTC, confirm we can reproduce data frame. wadt_my = self.gfb.get_wadt_py(ptc_multiyear) wadt_apr26_2010 = (ptc_multiyear.data diff --git a/traffic_prophet/countmatch/tests/test_permcount.py b/traffic_prophet/countmatch/tests/test_permcount.py index 95c8471..e214800 100644 --- a/traffic_prophet/countmatch/tests/test_permcount.py +++ b/traffic_prophet/countmatch/tests/test_permcount.py @@ -26,6 +26,10 @@ def test_permcount(self, sample_counts): assert ptc.is_permanent assert ptc.perm_years == [2010, 2012] + with pytest.raises(AttributeError) as excinfo: + ptc.growth_factor + assert "PTC has not had its growth factor fit!" 
in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: ptc = pc.PermCount.from_count_object( sample_counts.counts[-104870], []) @@ -37,7 +41,7 @@ class TestPermCountProcessor: def test_setup(self, pcproc): assert isinstance(pcproc.dvc, dv.DerivedValsStandard) assert isinstance(pcproc.gfc, gf.GrowthFactorComposite) - # We passed a custom cfgcm with no excluded IDs. + # We passed a custom cfgcm with one excluded ID. assert pcproc.excluded_ids == [-446378, ] def test_partition(self, pcproc, sample_counts): diff --git a/traffic_prophet/countmatch/tests/test_reader.py b/traffic_prophet/countmatch/tests/test_reader.py index db363ab..e106c6e 100644 --- a/traffic_prophet/countmatch/tests/test_reader.py +++ b/traffic_prophet/countmatch/tests/test_reader.py @@ -120,8 +120,8 @@ def test_regularize_timeseries(self, rdr, counts): def test_preprocess_count_data(self, rdr, counts): ref = counts[7] - rd = ref.copy() # Do a deep copy because preprocess_count_data alters its arguments. + rd = ref.copy() rd['data'] = rd['data'].copy() rd = rdr.preprocess_count_data(rd)