diff --git a/src/ehrdata/io/omop/_queries.py b/src/ehrdata/io/omop/_queries.py index 6479622..552bd3c 100644 --- a/src/ehrdata/io/omop/_queries.py +++ b/src/ehrdata/io/omop/_queries.py @@ -115,7 +115,7 @@ def _generate_value_query(data_table: str, data_field_to_keep: Sequence, aggrega return is_present_query + value_query -def _time_interval_table( +def _write_long_time_interval_table( backend_handle: duckdb.duckdb.DuckDBPyConnection, time_defining_table: str, data_table: str, @@ -125,8 +125,7 @@ def _time_interval_table( aggregation_strategy: str, data_field_to_keep: Sequence[str] | str, keep_date: str = "", - return_as_df: bool = False, -) -> pd.DataFrame | None: +) -> None: if isinstance(data_field_to_keep, str): data_field_to_keep = [data_field_to_keep] @@ -219,127 +218,3 @@ def _time_interval_table( WHERE long_person_timestamp_feature_value.person_id = RP.person_id; """ backend_handle.execute(add_person_range_index_query) - - if return_as_df: - return backend_handle.execute("SELECT * FROM long_person_timestamp_feature_value").df() - else: - return None - - -# def _get_time_interval_table( -# backend_handle: duckdb.duckdb.DuckDBPyConnection, -# time_defining_table: str, -# data_table: str, -# interval_length_number: int, -# interval_length_unit: str, -# num_intervals: int, -# aggregation_strategy: str, -# data_field_to_keep: Sequence[str] | str, -# keep_date: str = "", -# ): -# return backend_handle.execute("SELECT * FROM long_person_timestamp_feature_value").df() - - -# def _time_interval_table_for_dataloader( -# backend_handle: duckdb.duckdb.DuckDBPyConnection, -# time_defining_table: str, -# data_table: str, -# interval_length_number: int, -# interval_length_unit: str, -# num_intervals: int, -# aggregation_strategy: str, -# data_field_to_keep: Sequence[str] | str, -# keep_date: str = "", -# ): -# if isinstance(data_field_to_keep, str): -# data_field_to_keep = [data_field_to_keep] - -# if keep_date == "": -# keep_date = "timepoint" - -# timedeltas_dataframe = _generate_timedeltas(interval_length_number, interval_length_unit, num_intervals) - -# _write_timedeltas_to_db( -# backend_handle, -# timedeltas_dataframe, -# ) - -# # multi-step query -# # 1. Create person_time_defining_table, which matches the one created for obs. Needs to contain the person_id, and the start date in particular. -# # 2. Create person_data_table (data_table is typically measurement), which contains the cross product of person_id and the distinct concept_id s. -# # 3. Create long_format_backbone, which is the left join of person_time_defining_table and person_data_table. -# # 4. Create long_format_intervals, which is the cross product of long_format_backbone and timedeltas. This table contains most notably the person_id, the concept_id, the interval start and end dates. -# # 5. Create the final table, which is the join with the data_table (typically measurement); each measurement is assigned to its person_id, its concept_id, and the interval it fits into. -# prepare_alias_query = f""" -# CREATE TABLE long_person_timestamp_feature_value AS \ -# WITH person_time_defining_table AS ( \ -# SELECT person.person_id as person_id, {DATA_TABLE_DATE_KEYS["start"][time_defining_table]} as start_date, {DATA_TABLE_DATE_KEYS["end"][time_defining_table]} as end_date \ -# FROM person \ -# JOIN {time_defining_table} ON person.person_id = {time_defining_table}.{TIME_DEFINING_TABLE_SUBJECT_KEY[time_defining_table]} \ -# WHERE visit_concept_id = 262 \ -# ), \ -# person_data_table AS( \ -# WITH distinct_data_table_concept_ids AS ( \ -# SELECT DISTINCT {DATA_TABLE_CONCEPT_ID_TRUNK[data_table]}_concept_id -# FROM {data_table} \ -# ) -# SELECT person.person_id, {DATA_TABLE_CONCEPT_ID_TRUNK[data_table]}_concept_id as data_table_concept_id \ -# FROM person \ -# CROSS JOIN distinct_data_table_concept_ids \ -# ), \ -# long_format_backbone as ( \ -# SELECT person_time_defining_table.person_id, data_table_concept_id, start_date, end_date \ -# FROM person_time_defining_table \ -# LEFT JOIN person_data_table USING(person_id)\ -# ), \ -# long_format_intervals as ( \ -# SELECT person_id, data_table_concept_id, interval_step, start_date, start_date + interval_start_offset as interval_start, start_date + interval_end_offset as interval_end \ -# FROM long_format_backbone \ -# CROSS JOIN timedeltas \ -# ), \ -# data_table_with_presence_indicator as( \ -# SELECT *, 1 as is_present \ -# FROM {data_table} \ -# ) \ -# """ - -# if keep_date in ["timepoint", "start", "end"]: -# select_query = f""" -# SELECT lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end, {_generate_value_query("data_table_with_presence_indicator", data_field_to_keep, AGGREGATION_STRATEGY_KEY[aggregation_strategy])} \ -# FROM long_format_intervals as lfi \ -# LEFT JOIN data_table_with_presence_indicator ON lfi.person_id = data_table_with_presence_indicator.person_id AND lfi.data_table_concept_id = data_table_with_presence_indicator.{DATA_TABLE_CONCEPT_ID_TRUNK[data_table]}_concept_id AND data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS[keep_date][data_table]} BETWEEN lfi.interval_start AND lfi.interval_end \ -# GROUP BY lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end -# """ - -# elif keep_date == "interval": -# select_query = f""" -# SELECT lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end, {_generate_value_query("data_table_with_presence_indicator", data_field_to_keep, AGGREGATION_STRATEGY_KEY[aggregation_strategy])} \ -# FROM long_format_intervals as lfi \ -# LEFT JOIN data_table_with_presence_indicator ON lfi.person_id = data_table_with_presence_indicator.person_id \ -# AND lfi.data_table_concept_id = data_table_with_presence_indicator.{DATA_TABLE_CONCEPT_ID_TRUNK[data_table]}_concept_id \ -# AND (data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS["start"][data_table]} BETWEEN lfi.interval_start AND lfi.interval_end \ -# OR data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS["end"][data_table]} BETWEEN lfi.interval_start AND lfi.interval_end \ -# OR (data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS["start"][data_table]} < lfi.interval_start AND data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS["end"][data_table]} > lfi.interval_end)) \ -# GROUP BY lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end -# """ - -# query = prepare_alias_query + select_query -# backend_handle.execute("DROP TABLE IF EXISTS long_person_timestamp_feature_value") -# backend_handle.execute(query) -# add_person_range_index_query = """ -# ALTER TABLE long_person_timestamp_feature_value -# ADD COLUMN person_index INTEGER; - -# WITH RankedPersons AS ( -# SELECT person_id, -# ROW_NUMBER() OVER (ORDER BY person_id) - 1 AS idx -# FROM (SELECT DISTINCT person_id FROM long_person_timestamp_feature_value) AS unique_persons -# ) -# UPDATE long_person_timestamp_feature_value -# SET person_index = RP.idx -# FROM RankedPersons RP -# WHERE long_person_timestamp_feature_value.person_id = RP.person_id; -# """ -# backend_handle.execute(add_person_range_index_query) - -# return None diff --git a/src/ehrdata/io/omop/omop.py b/src/ehrdata/io/omop/omop.py index 9ab2c5a..452bff6 100644 --- a/src/ehrdata/io/omop/omop.py +++ b/src/ehrdata/io/omop/omop.py @@ -32,7 +32,7 @@ _check_valid_observation_table, _check_valid_variable_data_tables, ) -from ehrdata.io.omop._queries import _time_interval_table +from ehrdata.io.omop._queries import _write_long_time_interval_table DOWNLOAD_VERIFICATION_TAG = "download_verification_tag" @@ -345,30 +345,21 @@ def setup_variables( logging.warning(f"No data found in {data_tables[0]}. Returning edata without additional variables.") return edata - # TODO: if instantiate_tensor - ds = ( - _time_interval_table( - backend_handle=backend_handle, - time_defining_table=time_defining_table, - data_table=data_tables[0], - data_field_to_keep=data_field_to_keep, - interval_length_number=interval_length_number, - interval_length_unit=interval_length_unit, - num_intervals=num_intervals, - aggregation_strategy=aggregation_strategy, - return_as_df=True, - ) - .set_index(["person_id", "data_table_concept_id", "interval_step"]) - .to_xarray() + _write_long_time_interval_table( + backend_handle=backend_handle, + time_defining_table=time_defining_table, + data_table=data_tables[0], + data_field_to_keep=data_field_to_keep, + interval_length_number=interval_length_number, + interval_length_unit=interval_length_unit, + num_intervals=num_intervals, + aggregation_strategy=aggregation_strategy, ) - # TODO: if instantiate_tensor! rdbms backed, make ds independent but build on long table _check_one_unit_per_feature(backend_handle) - # TODO ignore? go with more vanilla omop style. _check_one_unit_per_feature(ds, unit_key="unit_source_value") - unit_report = _create_feature_unit_concept_id_report(backend_handle) - var = ds["data_table_concept_id"].to_dataframe() + var = backend_handle.execute("SELECT DISTINCT data_table_concept_id FROM long_person_timestamp_feature_value").df() if enrich_var_with_feature_info or enrich_var_with_unit_info: concepts = backend_handle.sql("SELECT * FROM concept").df() @@ -398,9 +389,19 @@ def setup_variables( suffixes=("", "_unit"), ) - t = ds["interval_step"].to_dataframe() + t = pd.DataFrame({"interval_step": np.arange(num_intervals)}) - edata = EHRData(r=ds[data_field_to_keep[0]].values, obs=edata.obs, var=var, uns=edata.uns, t=t) + if instantiate_tensor: + ds = ( + (backend_handle.execute("SELECT * FROM long_person_timestamp_feature_value").df()) + .set_index(["person_id", "data_table_concept_id", "interval_step"]) + .to_xarray() + ) + + else: + ds = None + + edata = EHRData(r=ds[data_field_to_keep[0]].values if ds else None, obs=edata.obs, var=var, uns=edata.uns, t=t) edata.uns[f"unit_report_{data_tables[0]}"] = unit_report return edata @@ -420,6 +421,7 @@ def setup_interval_variables( enrich_var_with_feature_info: bool = False, enrich_var_with_unit_info: bool = False, keep_date: Literal["start", "end", "interval"] = "start", + instantiate_tensor: bool = True, ): """Setup the interval variables @@ -453,6 +455,8 @@ def setup_interval_variables( Whether to enrich the var table with feature information. If a concept_id is not found in the concept table, the feature information will be NaN. date_type Whether to keep the start or end date, or the interval span. + instantiate_tensor + Whether to instantiate the tensor into the .r field of the EHRData object. Returns ------- @@ -483,24 +487,26 @@ def setup_interval_variables( logging.warning(f"No data in {data_tables}.") return edata + _write_long_time_interval_table( + backend_handle=backend_handle, + time_defining_table=time_defining_table, + data_table=data_tables[0], + data_field_to_keep=data_field_to_keep, + interval_length_number=interval_length_number, + interval_length_unit=interval_length_unit, + num_intervals=num_intervals, + aggregation_strategy=aggregation_strategy, + keep_date=keep_date, + ) + ds = ( - _time_interval_table( - backend_handle=backend_handle, - time_defining_table=time_defining_table, - data_table=data_tables[0], - data_field_to_keep=data_field_to_keep, - interval_length_number=interval_length_number, - interval_length_unit=interval_length_unit, - num_intervals=num_intervals, - aggregation_strategy=aggregation_strategy, - keep_date=keep_date, - return_as_df=True, - ) + backend_handle.execute("SELECT * FROM long_person_timestamp_feature_value") + .df() .set_index(["person_id", "data_table_concept_id", "interval_step"]) .to_xarray() ) - var = ds["data_table_concept_id"].to_dataframe() + var = backend_handle.execute("SELECT DISTINCT data_table_concept_id FROM long_person_timestamp_feature_value").df() if enrich_var_with_feature_info or enrich_var_with_unit_info: concepts = backend_handle.sql("SELECT * FROM concept").df() @@ -509,7 +515,7 @@ def setup_interval_variables( if enrich_var_with_feature_info: var = pd.merge(var, concepts, how="left", left_index=True, right_on="concept_id") - t = ds["interval_step"].to_dataframe() + t = pd.DataFrame({"interval_step": np.arange(num_intervals)}) edata = EHRData(r=ds[data_field_to_keep[0]].values, obs=edata.obs, var=var, uns=edata.uns, t=t)