diff --git a/concatenator/attribute_handling.py b/concatenator/attribute_handling.py index b99fcab..b98343d 100644 --- a/concatenator/attribute_handling.py +++ b/concatenator/attribute_handling.py @@ -29,33 +29,32 @@ def regroup_coordinate_attribute(attribute_string: str) -> str: """ # Use the separator that's in the attribute string only if all separators in the string are the same. # Otherwise, we will use our own default separator. - whitespaces = re.findall(r'\s+', attribute_string) + whitespaces = re.findall(r"\s+", attribute_string) if len(set(whitespaces)) <= 1: new_sep = whitespaces[0] else: new_sep = COORD_DELIM return new_sep.join( - '/'.join(c.split(GROUP_DELIM))[1:] - for c - in attribute_string.split() # split on any whitespace + "/".join(c.split(GROUP_DELIM))[1:] + for c in attribute_string.split() # split on any whitespace ) -def flatten_coordinate_attribute_paths(dataset: netCDF4.Dataset, - var: netCDF4.Variable, - variable_name: str) -> None: +def flatten_coordinate_attribute_paths( + dataset: netCDF4.Dataset, var: netCDF4.Variable, variable_name: str +) -> None: """Flatten the paths of variables referenced in the coordinates attribute.""" - if 'coordinates' in var.ncattrs(): - coord_att = var.getncattr('coordinates') + if "coordinates" in var.ncattrs(): + coord_att = var.getncattr("coordinates") new_coord_att = _flatten_coordinate_attribute(coord_att) - dataset.variables[variable_name].setncattr('coordinates', new_coord_att) + dataset.variables[variable_name].setncattr("coordinates", new_coord_att) def _flatten_coordinate_attribute(attribute_string: str) -> str: - """Converts attributes that specify group membership via "/" to use new group delimiter, even for the root level. + """Converts attributes with "/" delimiters to use new group delimiter, even for the root level. Examples -------- @@ -73,15 +72,18 @@ def _flatten_coordinate_attribute(attribute_string: str) -> str: """ # Use the separator that's in the attribute string only if all separators in the string are the same. # Otherwise, we will use our own default separator. - whitespaces = re.findall(r'\s+', attribute_string) - if len(set(whitespaces)) <= 1: + whitespaces = re.findall(r"\s+", attribute_string) + if len(set(whitespaces)) == 1: new_sep = whitespaces[0] else: new_sep = COORD_DELIM # A new string is constructed. - return new_sep.join( - f'{GROUP_DELIM}{c.replace("/", GROUP_DELIM)}' - for c - in attribute_string.split() # split on any whitespace - ) + return new_sep.join(flatten_variable_path_str(item) for item in attribute_string.split()) + + +def flatten_variable_path_str(path_str: str) -> str: + """Converts a path with "/" delimiters to use new group delimiter, even for the root level.""" + new_path = path_str.replace("/", GROUP_DELIM) + + return f"{GROUP_DELIM}{new_path}" if not new_path.startswith(GROUP_DELIM) else new_path diff --git a/concatenator/stitchee.py b/concatenator/stitchee.py index d1f2c8d..a99aeb7 100644 --- a/concatenator/stitchee.py +++ b/concatenator/stitchee.py @@ -9,6 +9,7 @@ import xarray as xr from concatenator import GROUP_DELIM +from concatenator.attribute_handling import flatten_variable_path_str from concatenator.dimension_cleanup import remove_duplicate_dims from concatenator.file_ops import add_label_to_path from concatenator.group_handling import ( @@ -27,6 +28,7 @@ def stitchee( concat_method: str = "xarray-concat", concat_dim: str = "", concat_kwargs: dict | None = None, + variables_to_include: list[str] | None = None, logger: Logger = default_logger, ) -> str: """Concatenate netCDF data files along an existing dimension. @@ -35,8 +37,16 @@ def stitchee( ---------- files_to_concat : list[str] output_file : str - keep_tmp_files : bool + write_tmp_flat_concatenated : bool, optional + keep_tmp_files : bool, optional + concat_method : str, optional + Either 'xarray-concat' or 'xarray-combine' concat_dim : str, optional + concat_kwargs : dict, optional + Keyword arguments to pass through to the xarray concatenation method + variables_to_include : list[str], optional + Names of variables to include. All other variables are excluded from the result + logger : logging.Logger Returns @@ -59,6 +69,14 @@ def stitchee( "'concat_dim' was specified, but will not be used because xarray-combine method was selected." ) + # Convert variable names inputted to flattened versions + if variables_to_include is not None: + variables_to_include_flattened = [ + flatten_variable_path_str(v) for v in variables_to_include + ] + else: + variables_to_include_flattened = None + logger.info("Flattening all input files...") xrdataset_list = [] @@ -67,10 +85,21 @@ def stitchee( # The group structure is flattened. start_time = time.time() logger.info(" ..file %03d/%03d <%s>..", i + 1, num_input_files, filepath) - flat_dataset, coord_vars, _ = flatten_grouped_dataset( + flat_dataset, coord_vars, string_vars = flatten_grouped_dataset( nc.Dataset(filepath, "r"), filepath, ensure_all_dims_are_coords=True ) + if variables_to_include_flattened is not None: + variables_to_delete = [ + var_name + for var_name, _ in flat_dataset.variables.items() + if (var_name not in variables_to_include_flattened) + and (var_name not in coord_vars) + ] + + for var_name in variables_to_delete: + del flat_dataset.variables[var_name] + logger.info("Removing duplicate dimensions") flat_dataset = remove_duplicate_dims(flat_dataset) @@ -101,22 +130,24 @@ def stitchee( # coords='minimal', # compat='override') + # Establish default concatenation keyword arguments if not supplied as input. if concat_kwargs is None: concat_kwargs = {} + if "data_vars" not in concat_kwargs: + concat_kwargs["data_vars"] = "minimal" + if "coords" not in concat_kwargs: + concat_kwargs["coords"] = "minimal" + # Perform concatenation operation. if concat_method == "xarray-concat": combined_ds = xr.concat( xrdataset_list, dim=GROUP_DELIM + concat_dim, - data_vars="minimal", - coords="minimal", **concat_kwargs, ) elif concat_method == "xarray-combine": combined_ds = xr.combine_by_coords( xrdataset_list, - data_vars="minimal", - coords="minimal", **concat_kwargs, ) else: diff --git a/tests/test_concat.py b/tests/test_concat.py index 2ec0f82..71bbb48 100644 --- a/tests/test_concat.py +++ b/tests/test_concat.py @@ -12,6 +12,7 @@ import pytest from concatenator import concat_with_nco +from concatenator.attribute_handling import flatten_variable_path_str from concatenator.stitchee import stitchee @@ -37,13 +38,14 @@ def run_verification_with_stitchee( concat_method: str = "xarray-concat", record_dim_name: str = "mirror_step", concat_kwargs: dict | None = None, + variables_to_include: list[str] | None = None, ): output_path = str(self.__output_path.joinpath(output_name)) # type: ignore data_path = self.__test_data_path.joinpath(data_dir) # type: ignore input_files = [] for filepath in data_path.iterdir(): - if Path(filepath).suffix.lower() in (".nc", ".h5", ".hdf"): + if Path(filepath).suffix.lower() in (".nc", ".nc4", ".h5", ".hdf"): copied_input_new_path = self.__output_path / Path(filepath).name # type: ignore shutil.copyfile(filepath, copied_input_new_path) input_files.append(str(copied_input_new_path)) @@ -59,16 +61,27 @@ def run_verification_with_stitchee( concat_method=concat_method, concat_dim=record_dim_name, concat_kwargs=concat_kwargs, + variables_to_include=variables_to_include, ) - merged_dataset = nc.Dataset(output_path) - # Verify that the length of the record dimension in the concatenated file equals # the sum of the lengths across the input files length_sum = 0 for file in input_files: - length_sum += len(nc.Dataset(file).variables[record_dim_name]) - assert length_sum == len(merged_dataset.variables[record_dim_name]) + with nc.Dataset(file) as ds: + length_sum += ds.dimensions[flatten_variable_path_str(record_dim_name)].size + + with nc.Dataset(output_path) as merged_dataset: + if record_dim_name in merged_dataset.variables: + # Primary dimension is a root level variable + assert length_sum == len(merged_dataset.variables[record_dim_name]) + elif record_dim_name in merged_dataset.dimensions: + # Primary dimension is a root level dimension, but not a variable + assert length_sum == merged_dataset.dimensions[record_dim_name].size + else: + raise AttributeError( + "Unexpected condition, where primary record dimension is not at the root level." + ) def run_verification_with_nco(self, data_dir, output_name, record_dim_name="mirror_step"): output_path = str(self.__output_path.joinpath(output_name)) diff --git a/tests/test_group_handling.py b/tests/test_group_handling.py index 0515910..bd2c5c8 100644 --- a/tests/test_group_handling.py +++ b/tests/test_group_handling.py @@ -2,27 +2,61 @@ # pylint: disable=C0116, C0301 -from concatenator.attribute_handling import (_flatten_coordinate_attribute, - regroup_coordinate_attribute) +from concatenator.attribute_handling import ( + _flatten_coordinate_attribute, + regroup_coordinate_attribute, +) -def test_coordinate_attribute_flattening(): +def test_coordinate_attribute_flattening_with_no_leading_slash(): # Case with groups present and double spaces. - assert _flatten_coordinate_attribute( - "Time_and_Position/time Time_and_Position/instrument_fov_latitude Time_and_Position/instrument_fov_longitude" - ) == '__Time_and_Position__time __Time_and_Position__instrument_fov_latitude __Time_and_Position__instrument_fov_longitude' + assert ( + _flatten_coordinate_attribute( + "Time_and_Position/time Time_and_Position/instrument_fov_latitude Time_and_Position/instrument_fov_longitude" + ) + == "__Time_and_Position__time __Time_and_Position__instrument_fov_latitude __Time_and_Position__instrument_fov_longitude" + ) # Case with NO groups present and single spaces. - assert _flatten_coordinate_attribute( - "time longitude latitude ozone_profile_pressure ozone_profile_altitude" - ) == "__time __longitude __latitude __ozone_profile_pressure __ozone_profile_altitude" + assert ( + _flatten_coordinate_attribute( + "time longitude latitude ozone_profile_pressure ozone_profile_altitude" + ) + == "__time __longitude __latitude __ozone_profile_pressure __ozone_profile_altitude" + ) + + +def test_coordinate_attribute_flattening_with_a_leading_slash(): + # Case with groups present and double spaces. + assert ( + _flatten_coordinate_attribute( + "/Time_and_Position/time /Time_and_Position/instrument_fov_latitude /Time_and_Position/instrument_fov_longitude" + ) + == "__Time_and_Position__time __Time_and_Position__instrument_fov_latitude __Time_and_Position__instrument_fov_longitude" + ) + + # Case with NO groups present and single spaces. + assert ( + _flatten_coordinate_attribute( + "/time /longitude /latitude /ozone_profile_pressure /ozone_profile_altitude" + ) + == "__time __longitude __latitude __ozone_profile_pressure __ozone_profile_altitude" + ) def test_coordinate_attribute_regrouping(): # Case with groups present and double spaces. - assert regroup_coordinate_attribute( - '__Time_and_Position__time __Time_and_Position__instrument_fov_latitude __Time_and_Position__instrument_fov_longitude') == "Time_and_Position/time Time_and_Position/instrument_fov_latitude Time_and_Position/instrument_fov_longitude" + assert ( + regroup_coordinate_attribute( + "__Time_and_Position__time __Time_and_Position__instrument_fov_latitude __Time_and_Position__instrument_fov_longitude" + ) + == "Time_and_Position/time Time_and_Position/instrument_fov_latitude Time_and_Position/instrument_fov_longitude" + ) # Case with NO groups present and single spaces. - assert regroup_coordinate_attribute( - "__time __longitude __latitude __ozone_profile_pressure __ozone_profile_altitude") == "time longitude latitude ozone_profile_pressure ozone_profile_altitude" + assert ( + regroup_coordinate_attribute( + "__time __longitude __latitude __ozone_profile_pressure __ozone_profile_altitude" + ) + == "time longitude latitude ozone_profile_pressure ozone_profile_altitude" + )