Skip to content

Commit

Permalink
added cohort constraint to valueset levels (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
hyi authored Mar 1, 2024
1 parent c7483f5 commit 8b62d13
Showing 1 changed file with 75 additions and 66 deletions.
141 changes: 75 additions & 66 deletions icees_api/features/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,47 +38,37 @@ def get_digest(*args):
return c.digest()


def op_dict(k, v, table_):
try:
value = table_.c[k]
except KeyError:
raise HTTPException(status_code=400, detail=f"No feature named '{k}'")
# python_type = value.type.python_type
# if v["operator"] == "in":
# values = v["values"]
# elif v["operator"] == "between":
# values = [v["value_a"], v["value_b"]]
# else:
# values = [v["value"]]
# options = features_dict[table_name][k].get("enum", None)
# for value in values:
# if not isinstance(value, python_type):
# raise HTTPException(
# status_code=400,
# detail="'{feature}' should be of type {type}, but {value} is not".format(
# value=value,
# feature=k,
# type=python_type,
# )
# )
# if options is not None and value not in options:
# raise HTTPException(
# status_code=400,
# detail="{value} is not in {options}".format(
# value=value,
# options=options
# )
# )

def op_dict(k, v, table_=None):
if table_ is not None:
try:
value = table_.c[k]
except KeyError:
raise HTTPException(status_code=400, detail=f"No feature named '{k}'")
else:
k_op_val_dict = get_level_operator_and_value(k)
value = simplify_value(k_op_val_dict['value'], v["operator"])
if isinstance(value, int):
# for < and > operators in value set definition, need to change the value to be <= or >= equivalent
# in order to compare the value against the same variable cohort definition operator
# For example, >9 should be equivalent to >=10, and if cohort definition is <10, and value set definition
# is >9, >9 rows should be filtered out by the cohort constraint
if k_op_val_dict['operator'] == '<':
value = value - 1
if k_op_val_dict['operator'] == '>':
value = value + 1


# v is a dict with "operator" key; other keys depend on the "operator" value
operations = {
">": lambda: value > v["value"],
"<": lambda: value < v["value"],
">=": lambda: value >= v["value"],
"<=": lambda: value <= v["value"],
"=": lambda: value == v["value"],
"<>": lambda: value != v["value"],
"between": lambda: between(value, v["value_a"], v["value_b"]),
"in": lambda: value.in_(v["values"])
">": lambda: value > simplify_value(v["value"], v["operator"]),
"<": lambda: value < simplify_value(v["value"], v["operator"]),
">=": lambda: value >= simplify_value(v["value"], v["operator"]),
"<=": lambda: value <= simplify_value(v["value"], v["operator"]),
"=": lambda: value == simplify_value(v["value"], v["operator"]),
"<>": lambda: value != simplify_value(v["value"], v["operator"]),
"between": lambda: between(value, simplify_value(v["value_a"], v["operator"]),
simplify_value(v["value_b"], v["operator"])),
"in": lambda: value.in_([simplify_value(val, v["operator"]) for val in v["values"]])
}
return operations[v["operator"]]()

Expand All @@ -87,26 +77,26 @@ def filter_select(s, k, v, table_):
"""Add WHERE clause to selection."""
return s.where(
op_dict(
k, v, table_,
k, v, table_=table_,
)
)


def case_select(table, k, v, table_name=None):
def case_select(table_, k, v):
return func.coalesce(func.sum(case([(
op_dict(
k, v, table,
k, v, table_=table_,
), 1
)], else_=0)), 0)


def case_select2(table, table2, k, v, k2, v2, table_name=None):
def case_select2(table1, table2, k, v, k2, v2):
return func.coalesce(func.sum(case([(and_(
op_dict(
k, v, table,
k, v, table_=table1,
),
op_dict(
k2, v2, table2,
k2, v2, table_=table2,
)
), 1)], else_=0)), 0)

Expand Down Expand Up @@ -848,12 +838,22 @@ def select_feature_count_all_values(
return count


def get_feature_levels(feature, year=None):
def get_feature_levels(feature, year=None, cohort_feat_dict=None):
    """Return the value-set levels of ``feature``, narrowed by year and cohort.

    The levels are looked up in ``get_value_sets()`` (empty list when the
    feature is unknown).  When ``feature`` is ``'year'`` and the supplied
    ``year`` is one of its levels, only that year is kept.  If
    ``cohort_feat_dict`` has an entry for ``feature``, the levels are further
    filtered to those for which ``op_dict(level, constraint)`` is truthy.

    :param feature: name of the feature whose levels are requested
    :param year: optional year used to pin the ``'year'`` feature to one level
    :param cohort_feat_dict: optional mapping of feature name to a cohort
        constraint dict (passed unchanged to ``op_dict``)
    :return: list of levels that survive the filters
    """
    levels = get_value_sets().get(feature, [])

    # Pin the 'year' feature to the passed-in year when it is a known level.
    if year and feature == 'year' and int(year) in levels:
        levels = [int(year)]

    # In the pinned case ``feature`` is 'year', so a single lookup keyed on
    # ``feature`` covers both the year-specific and the general filter.
    if cohort_feat_dict:
        for name, constraint in cohort_feat_dict.items():
            if name == feature:
                return [lev for lev in levels if op_dict(lev, constraint)]

    return levels


Expand Down Expand Up @@ -1034,6 +1034,23 @@ def validate_feature_value_in_table_column_for_equal_operator(conn, table_name,
return


def get_level_operator_and_value(input_level):
    """Split one value-set level into its comparison operator and operand.

    A string level may start with a run of ``<`` / ``>`` characters (for
    example ``'>9'``, ``'<10'`` or ``'<>3'``); that run is the operator and
    the remainder of the string is the value.  Any other level -- a string
    without such a prefix, or a non-string such as an int -- is treated as an
    equality level and returned unchanged as the value.

    :param input_level: a single level taken from a feature's value set
    :return: dict with keys ``"operator"`` and ``"value"``
    """
    if isinstance(input_level, str):
        # lstrip removes only the leading run of '<' / '>' characters, which
        # is exactly the operator prefix.
        operand = input_level.lstrip('<>')
        prefix_len = len(input_level) - len(operand)
        if prefix_len:
            return {"operator": input_level[:prefix_len], "value": operand}
    return {"operator": "=", "value": input_level}


def get_operator_and_value(input_levels, feat_name, append_feature_variable=False):
"""
get operator and value from each input level which will be in the format of '>' or '<' followed by a number or
Expand All @@ -1042,23 +1059,11 @@ def get_operator_and_value(input_levels, feat_name, append_feature_variable=Fals
"""
fqs = []
for input_level in input_levels:
non_op_idx = 0
if isinstance(input_level, str):
for lev in input_level:
if lev in ['<', '>']:
non_op_idx += 1
else:
break
if non_op_idx == 0:
op = '='
op_val = input_level
else:
op = input_level[:non_op_idx]
op_val = input_level[non_op_idx:]
op_val_dict = get_level_operator_and_value(input_level)
if append_feature_variable:
fqs.append({feat_name: {"operator": op, "value": op_val}})
fqs.append({feat_name: op_val_dict})
else:
fqs.append({"operator": op, "value": op_val})
fqs.append(op_val_dict)
return fqs


Expand All @@ -1074,23 +1079,26 @@ def compute_multivariate_table(conn, table_name, year, cohort_id, feature_variab
"for computing multivariate associations")

# get feature_constraint list from the first feature variable
feat_constraint_list = get_operator_and_value(get_feature_levels(feature_variables[0], year=year),
feat_constraint_list = get_operator_and_value(get_feature_levels(feature_variables[0], year=year,
cohort_feat_dict=cohort_features),
feature_variables[0], append_feature_variable=True)

index = 1
while index + 2 <= feat_len:
feature_as = [
{
"feature_name": feature_variables[index],
"feature_qualifiers": get_operator_and_value(get_feature_levels(feature_variables[index], year=year),
"feature_qualifiers": get_operator_and_value(get_feature_levels(feature_variables[index], year=year,
cohort_feat_dict=cohort_features),
feature_variables[index])
}
]
feature_bs = [
{
"feature_name": feature_variables[index + 1],
"feature_qualifiers": get_operator_and_value(get_feature_levels(feature_variables[index + 1],
year=year),
year=year,
cohort_feat_dict=cohort_features),
feature_variables[index + 1])
}
]
Expand All @@ -1116,7 +1124,8 @@ def compute_multivariate_table(conn, table_name, year, cohort_id, feature_variab
index += 2

if index < feat_len:
feature_qualifiers = get_operator_and_value(get_feature_levels(feature_variables[index], year=year),
feature_qualifiers = get_operator_and_value(get_feature_levels(feature_variables[index], year=year,
cohort_feat_dict=cohort_features),
feature_variables[index])
more_constraint_list = []
for feature_constraint in feat_constraint_list:
Expand Down

0 comments on commit 8b62d13

Please sign in to comment.