diff --git a/marquette/conf/config.yaml b/marquette/conf/config.yaml
index 5694c15..4fb2600 100644
--- a/marquette/conf/config.yaml
+++ b/marquette/conf/config.yaml
@@ -1,6 +1,6 @@
 name: MERIT
 data_path: /projects/mhpi/data/${name}
-zone: 73
+zone: 74
 create_edges:
   buffer: 0.3334
   dx: 2000
@@ -19,7 +19,7 @@ create_N:
 create_TMs:
   MERIT:
     save_sparse: True
-    TM: ${data_path}/zarr/TMs/MERIT_FLOWLINES_${zone}
+    TM: ${data_path}/zarr/TMs/sparse_MERIT_FLOWLINES_${zone}
     shp_files: ${data_path}/raw/basins/cat_pfaf_${zone}_MERIT_Hydro_v07_Basins_v01_bugfix1.shp
 create_streamflow:
   version: merit_conus_v3.0
diff --git a/marquette/merit/_TM_calculations.py b/marquette/merit/_TM_calculations.py
index 8c845cc..537868a 100644
--- a/marquette/merit/_TM_calculations.py
+++ b/marquette/merit/_TM_calculations.py
@@ -1,6 +1,7 @@
 import logging
 from pathlib import Path
 
+import binsparse
 import geopandas as gpd
 import matplotlib.pyplot as plt
 import numpy as np
@@ -8,6 +9,7 @@
 import xarray as xr
 import zarr
 from omegaconf import DictConfig
+from scipy import sparse
 from tqdm import tqdm
 
 log = logging.getLogger(__name__)
@@ -56,6 +58,89 @@ def create_HUC_MERIT_TM(
     xr_dataset.to_zarr(zarr_path, mode="w")
 
 
+def format_pairs(gage_output: dict):
+    """
+    Flatten per-COMID edge lists into (row, column) index pairs.
+
+    :param gage_output: dict with parallel "comid_idx" and "edge_id_idx" lists;
+        each "edge_id_idx" entry holds the edge indices of one COMID, given
+        either as scalars or as nested lists of scalars
+    :return: list of (comid_idx, edge_idx) tuples; a missing edge (None) is
+        emitted as np.nan
+    """
+    pairs = []
+    for comid, edge_id in zip(gage_output["comid_idx"], gage_output["edge_id_idx"]):
+        for edge in edge_id:
+            # A nested list means several connections share this entry
+            connections = edge if isinstance(edge, list) else [edge]
+            for _id in connections:
+                # np.nan, not np.NaN: that alias was removed in NumPy 2.0
+                pairs.append((comid, np.nan if _id is None else _id))
+
+    return pairs
+
+
+def create_coo_data(sparse_matrix, root: zarr.Group):
+    """
+    Write a transition matrix to Zarr in coordinate (COO) form.
+
+    The (row, column) index pairs are built from the "comid_idx" and
+    "edge_id_idx" entries of `sparse_matrix` and stored as the "pairs" array;
+    the matching matrix entries are stored as the "values" array.
+
+    :param sparse_matrix: dict with "values", "comid_idx" and "edge_id_idx" keys
+    :param root: the Zarr group that receives the "pairs" and "values" arrays
+    """
+    values = sparse_matrix["values"]
+    pairs = format_pairs(sparse_matrix)
+
+    # NOTE(review): "pairs" is stored as int32, which cannot hold the np.nan
+    # that format_pairs emits for a missing edge - confirm inputs never have None
+    root.create_dataset("pairs", data=np.array(pairs), chunks=(10000,), dtype="int32")
+    root.array("values", data=np.array(values), chunks=(10000,), dtype="float32")
+
+
+def create_sparse_MERIT_FLOW_TM(
+    cfg: DictConfig, edges: zarr.hierarchy.Group
+) -> zarr.hierarchy.Group:
+    """
+    Creating a sparse TM that maps MERIT basins to their reaches. Flow predictions
+    are distributed based on reach length / total merit reach length
+
+    :param cfg: config providing create_TMs.MERIT.TM, the output Zarr path
+    :param edges: edge group exposing the merit_basin and len arrays
+    :return: the Zarr group holding the "pairs" and "values" arrays
+    """
+    log.info("Using Edge COMIDs for TM")
+    merit_basin = edges.merit_basin[:]  # read the store once and reuse it
+    COMIDs = np.unique(merit_basin)  # already sorted
+    gage_coo_root = zarr.open_group(Path(cfg.create_TMs.MERIT.TM), mode="a")
+    river_graph_len = edges.len[:]
+    river_graph = {"values": [], "comid_idx": [], "edge_id_idx": []}
+    for comid_idx, basin_id in enumerate(
+        tqdm(
+            COMIDs,
+            desc="Creating a sparse TM Mapping MERIT basins to their edges",
+            ncols=140,
+            ascii=True,
+        )
+    ):
+        col_indices = np.where(merit_basin == basin_id)[0]
+        edge_lengths = river_graph_len[col_indices]
+        total_length = np.sum(edge_lengths)
+        if total_length == 0:
+            # basin_id was taken from merit_basin, so the basin exists; its
+            # proportions are simply undefined when the lengths sum to zero
+            log.warning("Basin %s has zero total edge length; skipping", basin_id)
+            continue
+        proportions = edge_lengths / total_length
+        river_graph["comid_idx"].append(comid_idx)
+        river_graph["edge_id_idx"].append(col_indices.tolist())
+        river_graph["values"].extend(proportions.tolist())
+    create_coo_data(river_graph, gage_coo_root)
+    return gage_coo_root
+
+
 def create_MERIT_FLOW_TM(
     cfg: DictConfig, edges: zarr.hierarchy.Group
 ) -> zarr.hierarchy.Group:
@@ -113,19 +198,27 @@ def create_MERIT_FLOW_TM(
         for idx, proportion in zip(indices, proportions):
             column_index = np.where(river_graph_ids == river_graph_ids[idx])[0][0]
             data_np[i][column_index] = proportion
-    data_array = xr.DataArray(
-        data=data_np,
-        dims=["COMID", "EDGEID"],  # Explicitly naming the dimensions
-        coords={"COMID": COMIDs, "EDGEID": river_graph_ids},  # Adding coordinates
-    )
-    xr_dataset = xr.Dataset(
-        data_vars={"TM": data_array},
-        attrs={"description": "MERIT -> Edge Transition Matrix"},
-    )
-    log.info("Writing MERIT TM to zarr store")
-    zarr_path = Path(cfg.create_TMs.MERIT.TM)
-    xr_dataset.to_zarr(zarr_path, mode="w")
-    # zarr_hierarchy = zarr.open_group(Path(cfg.create_TMs.MERIT.TM), mode="r")
+
+    if cfg.create_TMs.MERIT.save_sparse:
+        log.info("Writing to sparse matrix")
+        gage_coo_root = zarr.open_group(Path(cfg.create_TMs.MERIT.TM), mode="a")
+        matrix = sparse.csr_matrix(data_np)
+        binsparse.write(gage_coo_root, "TM", matrix)
+        log.info("Sparse matrix written")
+    else:
+        data_array = xr.DataArray(
+            data=data_np,
+            dims=["COMID", "EDGEID"],  # Explicitly naming the dimensions
+            coords={"COMID": COMIDs, "EDGEID": river_graph_ids},  # Adding coordinates
+        )
+        xr_dataset = xr.Dataset(
+            data_vars={"TM": data_array},
+            attrs={"description": "MERIT -> Edge Transition Matrix"},
+        )
+        log.info("Writing MERIT TM to zarr store")
+        zarr_path = Path(cfg.create_TMs.MERIT.TM)
+        xr_dataset.to_zarr(zarr_path, mode="w")
+        # zarr_hierarchy = zarr.open_group(Path(cfg.create_TMs.MERIT.TM), mode="r")
 
 
 def join_geospatial_data(cfg: DictConfig) -> gpd.GeoDataFrame:
diff --git a/marquette/merit/create.py b/marquette/merit/create.py
index 8fe5095..c6da455 100644
--- a/marquette/merit/create.py
+++ b/marquette/merit/create.py
@@ -31,6 +31,7 @@
 from marquette.merit._TM_calculations import (
     create_HUC_MERIT_TM,
     create_MERIT_FLOW_TM,
+    # create_sparse_MERIT_FLOW_TM,
     join_geospatial_data,
 )
 
@@ -237,4 +238,7 @@ def create_TMs(cfg: DictConfig, edges: zarr.Group) -> None:
         log.info("MERIT -> FLOWLINE data already exists in zarr format")
     else:
         log.info("Creating MERIT -> FLOWLINE TM")
+        # if cfg.create_TMs.MERIT.save_sparse:
+        #     create_sparse_MERIT_FLOW_TM(cfg, edges)
+        # else:
         create_MERIT_FLOW_TM(cfg, edges)
diff --git a/requirements.txt b/requirements.txt
index 5db2daa..fdb9120 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+git+https://github.com/ivirshup/binsparse-python.git@main
 dask[complete]
 polars
 dask-expr