Skip to content

Commit

Permalink
#32 address PR comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles Zhu committed Jan 21, 2020
1 parent e2018db commit e215643
Show file tree
Hide file tree
Showing 10 changed files with 323 additions and 138 deletions.
2 changes: 1 addition & 1 deletion traffic_prophet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# For countmatch
cm = {
'verbose': True,
'verbose': False,
'min_count': 96,
'min_counts_in_day': 24,
'min_permanent_months': 12,
Expand Down
102 changes: 93 additions & 9 deletions traffic_prophet/countmatch/derivedvals.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DVRegistrar(type):

def __init__(cls, name, bases, dct):

# Register GrowthFactorBase subclass if `_dv_type` not already taken.
# Register DerivedVal subclass if `_dv_type` not already taken.
if name not in ('DerivedValsBase', 'DerivedVals'):
if not hasattr(cls, "_dv_type"):
raise ValueError("must define a `_dv_type`.")
Expand Down Expand Up @@ -57,13 +57,40 @@ class DerivedValsBase(metaclass=DVRegistrar):

@staticmethod
def preprocess_daily_counts(dc):
"""Preprocess daily counts by adding month and day of week.
Parameters
---------------
dc : pandas.DataFrame
Daily counts.
Returns
-------
dca : pandas.DataFrame
Copy of `dc` with 'Month' and 'Day of Week' columns.
"""
# Ensure dc remains unchanged by the method.
dca = dc.copy()
dca['Month'] = dca['Date'].dt.month
dca['Day of Week'] = dca['Date'].dt.dayofweek
return dca

def get_madt(self, dca):
"""Get mean average daily traffic (MADT) from processed daily counts.
Parameters
---------------
dca : pandas.DataFrame
Daily counts, with 'Month' and 'Day of Week' columns from
`preprocess_daily_counts`.
Returns
-------
madt : pandas.DataFrame
MADT table, with 'Days Available' and 'Days in Month' columns.
"""
dc_m = dca.groupby(['Year', 'Month'])
madt = pd.DataFrame({
'MADT': dc_m['Daily Count'].mean(),
Expand All @@ -73,9 +100,10 @@ def get_madt(self, dca):
names=['Year', 'Month'])
)

# Loop to record number of days in month.
# Loop to record the number of days in each month. February requires
# recalculating in case of leap years.
days_in_month = []
for year in dca.index.levels[0]:
for year in madt.index.levels[0]:
cdays = self._months_of_year.copy()
cdays[1] = pd.to_datetime('{0:d}-02-01'.format(year)).daysinmonth
days_in_month += cdays
Expand All @@ -85,17 +113,56 @@ def get_madt(self, dca):

@staticmethod
def get_aadt_py_from_madt(madt, perm_years):
# Weighted average for AADT.
"""Annual average daily traffic (AADT) from an MADT weighted average.
Parameters
---------------
madt : pandas.DataFrame
MADT, with 'Days in Month' column as from `get_madt`.
perm_years : list
List of permanent count years for location; obtained from
PermCount.perm_years.
Returns
-------
aadt : pandas.DataFrame
AADT table.
"""
madt_py = madt.loc[perm_years, :]
monthly_total_traffic = madt_py['MADT'] * madt_py['Days in Month']
return pd.DataFrame(
{'AADT': (monthly_total_traffic.groupby('Year').sum() /
madt_py.groupby('Year')['Days in Month'].sum())})

@staticmethod
def get_ratios_py(dca, madt, aadt, perm_years):
dca_py = dca.loc[perm_years].copy()
dc_dom = dca_py.groupby(['Year', 'Month', 'Day of Week'])
def get_ratios_py(dca, madt, aadt_py, perm_years):
"""Ratios between MADT and AADT and day-of-month average daily traffic.
Parameters
---------------
dca : pandas.DataFrame
Daily counts, with 'Month' and 'Day of Week' columns from
`preprocess_daily_counts`.
madt : pandas.DataFrame
MADT, with 'Days in Month' column as from `get_madt`.
aadt_py : pandas.DataFrame
AADT for permanent years, as from `get_aadt_py_from_madt`.
perm_years : list
List of permanent count years for location; obtained from
PermCount.perm_years.
Returns
-------
dom_ijd : pandas.DataFrame
Ratio between MADT and day-of-month ADT.
d_ijd : pandas.DataFrame
Ratio between AADT and day-of-month ADT.
n_avail_days : pandas.DataFrame
Number of days used to calculate day-of-month ADT.
"""
dc_dom = dca.loc[perm_years].groupby(['Year', 'Month', 'Day of Week'])
# Multi-index levels retain all values from dca even after using loc,
# so can only use `perm_years`.
ymd_index = pd.MultiIndex.from_product(
Expand All @@ -112,8 +179,9 @@ def get_ratios_py(dca, madt, aadt, perm_years):
# broadcasting trick.)
dom_ijd = (madt['MADT'].loc[perm_years, :].values[:, np.newaxis] /
domadt)
# Determine day-to-year conversion factor D_ijd.
d_ijd = (aadt['AADT'].values[:, np.newaxis] /
# Determine day-to-year conversion factor D_ijd. (Uses broadcasting
# and pivoting pandas columns.)
d_ijd = (aadt_py['AADT'].values[:, np.newaxis] /
domadt.unstack(level=-1)).stack()

return dom_ijd, d_ijd, n_avail_days
Expand All @@ -130,6 +198,16 @@ def __init__(self, impute_ratios=False):
self._impute_ratios = impute_ratios

def get_derived_vals(self, ptc):
"""Get derived values, including ADTs and ratios between them.
Depending on settings, will also impute missing values.
Parameters
----------
ptc : permcount.PermCount
Permanent count instance.
"""
dca = self.preprocess_daily_counts(ptc.data)
madt = self.get_madt(dca)
aadt = self.get_aadt_py_from_madt(madt, ptc.perm_years)
Expand All @@ -139,3 +217,9 @@ def get_derived_vals(self, ptc):
ptc.adts = {'MADT': madt, 'AADT': aadt}
ptc.ratios = {'DoM_ijd': dom_ijd, 'D_ijd': d_ijd,
'N_avail_days': n_avail_days}

if self._impute_ratios:
self.impute_ratios(ptc)

def impute_ratios(self, ptc):
raise NotImplementedError
56 changes: 52 additions & 4 deletions traffic_prophet/countmatch/growthfactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,39 @@ class GrowthFactorBase(metaclass=GFRegistrar):

@staticmethod
def get_aadt(tc):
"""Get AADT (for permanent years)."""
"""Get AADT (for permanent years).
Parameters
----------
tc : permcount.PermCount
Permanent count instance.
Returns
-------
aadt : pandas.DataFrame
Modified AADT table where 'Year' is a column, not an index.
"""
aadt = tc.adts['AADT'].reset_index()
aadt['Year'] = aadt['Year'].astype(float)
return aadt

@staticmethod
def get_wadt_py(tc):
"""Get WADT for all full weeks within permanent years."""
"""Get WADT for all full weeks within permanent years.
Parameters
----------
tc : permcount.PermCount
Permanent count instance.
Returns
-------
wadt : pandas.DataFrame
Weekly average daily traffic, and 'Time' in weeks from the start of
the first permanent count year (in `perm_years`).
"""
cdata = tc.data.loc[tc.perm_years, :].copy()

# Corrective measure, as dt.week returns the "week ordinal". See
Expand Down Expand Up @@ -122,6 +147,14 @@ def exponential_rate_fit(year, aadt, ref_vals):
return model.fit()

def fit_growth(self, tc):
"""Fit an exponential growth module to AADT vs. year.
Parameters
----------
tc : permcount.PermCount
Permanent count instance.
"""
# Process year vs. AADT data.
aadt = self.get_aadt(tc)

Expand Down Expand Up @@ -174,7 +207,14 @@ def linear_rate_fit(week, wadt):
return model.fit()

def fit_growth(self, tc):
# Process week vs. weekly averaged ADT.
"""Fit a linear growth model to WADT vs. time in weeks.
Parameters
----------
tc : permcount.PermCount
Permanent count instance.
"""
wadt_py = self.get_wadt_py(tc)
fit_results = self.linear_rate_fit(wadt_py['Time'].values,
wadt_py['WADT'].values)
Expand All @@ -183,7 +223,7 @@ def fit_growth(self, tc):
# AADTs can only be calculated for perm_years, so consistent with
# `get_wadt_py`.
growth_factor = 1. + (fit_results.params[1] * 52. /
tc.adts['AADT']['AADT'].values[0])
tc.adts['AADT'].iat[0, 0])

tc._growth_fit = {'fit_type': 'Linear',
'fit_results': fit_results,
Expand All @@ -201,6 +241,14 @@ def __init__(self):
self.wadt_lin = GrowthFactorWADTLin()

def fit_growth(self, tc):
"""Fit a linear model for single years, exponential for multiple years.
Parameters
----------
tc : permcount.PermCount
Permanent count instance.
"""
if tc.adts['AADT'].shape[0] > 1:
self.aadt_exp.fit_growth(tc)
else:
Expand Down
7 changes: 4 additions & 3 deletions traffic_prophet/countmatch/permcount.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ def __init__(self, count_id, centreline_id, direction, data,

@property
def perm_years(self):
"""List of years where the count is permanent."""
return self._perm_years

@property
def growth_factor(self):
"""Multiplicative year-on-year growth factor."""
if self._growth_fit is None:
raise AttributeError("PTC has not had its growth factor fit!")
return self._growth_fit['growth_factor']
Expand All @@ -63,7 +65,7 @@ class PermCountProcessor:
dv_calc : DerivedVal instance
For imputation and derived properties.
gf_calc : GrowthFactor instance
For estimating growth factor.
For estimating the growth factor.
cfg : dict, optional
Configuration settings. Default: `config.cm`.
"""
Expand All @@ -72,8 +74,7 @@ def __init__(self, dv_calc, gf_calc, cfg=cfg.cm):
self.dvc = dv_calc
self.gfc = gf_calc
self.cfg = cfg
# Obtain a list of count_ids that (according to TEPs-I) should not be
# PTCs because they reduce the accuracy of CountMatch.
# Obtain a list of count_ids that should be excluded from being PTCs.
self.excluded_ids = (self.cfg['exclude_ptc_pos'] +
[-id for id in self.cfg['exclude_ptc_neg']])
self._disable_tqdm = not self.cfg['verbose']
Expand Down
2 changes: 1 addition & 1 deletion traffic_prophet/countmatch/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""CountMatch test suite preprocessing.
Fixtures used by multiple files in this and subdirectories must be placed here.
Fixtures used by multiple files in this and subdirectories are placed here.
See https://docs.pytest.org/en/latest/fixture.html#conftest-py-sharing-fixture-functions
for more.
Expand Down
4 changes: 3 additions & 1 deletion traffic_prophet/countmatch/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import operator

from .. import base


Expand All @@ -7,6 +9,6 @@ def test_count(self):
count = base.Count('test', 1, -1., None)
assert count.count_id == 'test'
assert count.centreline_id == 1
assert count.direction == -1
assert operator.index(count.direction) == -1
assert count.data is None
assert not count.is_permanent
Loading

0 comments on commit e215643

Please sign in to comment.