From a78d74557809f6ae579eb5770d309c8cea4b1e79 Mon Sep 17 00:00:00 2001 From: Oliver Lomax Date: Fri, 11 Oct 2024 20:43:27 +0100 Subject: [PATCH] Integrate `pack_vector_fields` into `SphericalVector` Interpolation method. (#224) --- .../method/sphericalvector/SphericalVector.cc | 32 ++++--- src/atlas/util/PackVectorFields.cc | 10 ++ .../test_interpolation_spherical_vector.cc | 91 +++++++++++++++++++ 3 files changed, 122 insertions(+), 11 deletions(-) diff --git a/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc b/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc index ab5f573d8..65e4e41d0 100644 --- a/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc +++ b/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc @@ -21,6 +21,7 @@ #include "atlas/runtime/Trace.h" #include "atlas/util/Constants.h" #include "atlas/util/Geometry.h" +#include "atlas/util/PackVectorFields.h" #include "eckit/config/LocalConfiguration.h" namespace atlas { @@ -95,10 +96,9 @@ void SphericalVector::do_setup(const FunctionSpace& source, const auto deltaAlpha = (alpha.first - alpha.second) * util::Constants::degreesToRadians(); - complexTriplets[dataIndex] = - ComplexTriplet{rowIndex, colIndex, - Complex{baseWeight * std::cos(deltaAlpha), - baseWeight * std::sin(deltaAlpha)}}; + complexTriplets[dataIndex] = ComplexTriplet{ + rowIndex, colIndex, + baseWeight * Complex{std::cos(deltaAlpha), std::sin(deltaAlpha)}}; realTriplets[dataIndex] = RealTriplet{rowIndex, colIndex, baseWeight}; } } @@ -120,9 +120,14 @@ void SphericalVector::do_execute(const FieldSet& sourceFieldSet, ATLAS_TRACE("atlas::interpolation::method::SphericalVector::do_execute()"); ATLAS_ASSERT(sourceFieldSet.size() == targetFieldSet.size()); - for (auto i = 0; i < sourceFieldSet.size(); ++i) { - do_execute(sourceFieldSet[i], targetFieldSet[i], metadata); + const auto packedSourceFieldSet = util::pack_vector_fields(sourceFieldSet); + auto packedTargetFieldSet = util::pack_vector_fields(targetFieldSet); + + for (auto i = 0; i < packedSourceFieldSet.size(); ++i) { + do_execute(packedSourceFieldSet[i], packedTargetFieldSet[i], metadata); } + + util::unpack_vector_fields(packedTargetFieldSet, targetFieldSet); } void SphericalVector::do_execute(const Field& sourceField, Field& targetField, @@ -130,7 +135,7 @@ void SphericalVector::do_execute(const Field& sourceField, Field& targetField, ATLAS_TRACE("atlas::interpolation::method::SphericalVector::do_execute()"); if (targetField.size() == 0) { - return; + return; } const auto fieldType = sourceField.metadata().getString("type", ""); @@ -156,9 +161,15 @@ void SphericalVector::do_execute_adjoint(FieldSet& sourceFieldSet, "atlas::interpolation::method::SphericalVector::do_execute_adjoint()"); ATLAS_ASSERT(sourceFieldSet.size() == targetFieldSet.size()); - for (auto i = 0; i < sourceFieldSet.size(); ++i) { - do_execute_adjoint(sourceFieldSet[i], targetFieldSet[i], metadata); + auto packedSourceFieldSet = util::pack_vector_fields(sourceFieldSet); + const auto packedTargetFieldSet = util::pack_vector_fields(targetFieldSet); + + for (auto i = 0; i < packedSourceFieldSet.size(); ++i) { + do_execute_adjoint(packedSourceFieldSet[i], packedTargetFieldSet[i], + metadata); } + + util::unpack_vector_fields(packedSourceFieldSet, sourceFieldSet); } void SphericalVector::do_execute_adjoint(Field& sourceField, @@ -168,7 +179,7 @@ void SphericalVector::do_execute_adjoint(Field& sourceField, "atlas::interpolation::method::SphericalVector::do_execute_adjoint()"); if (targetField.size() == 0) { - return; + return; } const auto fieldType = sourceField.metadata().getString("type", ""); @@ -192,7 +203,6 @@ template void SphericalVector::interpolate_vector_field(const Field& sourceField, Field& targetField, const MatMul& matMul) { - ATLAS_ASSERT_MSG(sourceField.variables() == 2 || sourceField.variables() == 3, "Vector field can only have 2 or 3 components."); diff --git a/src/atlas/util/PackVectorFields.cc b/src/atlas/util/PackVectorFields.cc index 640a1d46b..2734a1031 100644 --- a/src/atlas/util/PackVectorFields.cc +++ b/src/atlas/util/PackVectorFields.cc @@ -180,6 +180,15 @@ FieldSet pack_vector_fields(const FieldSet& fields, FieldSet packedFields) { componentFieldMetadataVector.push_back(componentFieldMetadata); vectorField.metadata().set("component_field_metadata", componentFieldMetadataVector); + + // If any component is dirty, the whole field is dirty. + if (vectorIndex == 0) { + vectorField.set_dirty(componentField.dirty()); + } else { + vectorField.set_dirty(vectorField.dirty() || componentField.dirty()); + } + + } return packedFields; } @@ -218,6 +227,7 @@ FieldSet unpack_vector_fields(const FieldSet& fields, FieldSet unpackedFields) { // Copy metadata. componentField.metadata() = componentFieldMetadata; + componentField.set_dirty(vectorField.dirty()); ++vectorIndex; } diff --git a/src/tests/interpolation/test_interpolation_spherical_vector.cc b/src/tests/interpolation/test_interpolation_spherical_vector.cc index fe48b47dd..398249d27 100644 --- a/src/tests/interpolation/test_interpolation_spherical_vector.cc +++ b/src/tests/interpolation/test_interpolation_spherical_vector.cc @@ -464,6 +464,97 @@ CASE("structured columns O96 vector interpolation (2d-field, 2-vector, hi-res)") testInterpolation((config)); } +CASE("separate vector field components") { + const auto sourceFunctionSpace = + FunctionSpaceFixtures::get("structured_columns"); + const auto targetFunctionSpace = + FunctionSpaceFixtures::get("cubedsphere_mesh"); + + auto sourceFieldSet = FieldSet{}; + auto targetFieldSet = FieldSet{}; + + const auto sourceLonLatView = + array::make_view(sourceFunctionSpace.lonlat()); + const auto targetLonLatView = + array::make_view(targetFunctionSpace.lonlat()); + + const auto createFieldView = [&](const FunctionSpace& functionSpace, + const std::string& name, + FieldSet& fieldSet) { + // Note: Vector field name can be anything that uniquely identifies field. + auto field = functionSpace.createField(option::name(name)); + field.metadata().set("vector_field_name", "wind"); + return array::make_view(fieldSet.add(field)); + }; + + auto uSourceView = createFieldView(sourceFunctionSpace, "u", sourceFieldSet); + auto vSourceView = createFieldView(sourceFunctionSpace, "v", sourceFieldSet); + const auto uTargetView = + createFieldView(targetFunctionSpace, "u", targetFieldSet); + const auto vTargetView = + createFieldView(targetFunctionSpace, "v", targetFieldSet); + + uSourceView.assign(0.); + vSourceView.assign(0.); + for (auto idx = idx_t{0}; idx < sourceFunctionSpace.size(); idx++) { + std::tie(uSourceView(idx), vSourceView(idx)) = + vortexHorizontal(sourceLonLatView(idx, 0), sourceLonLatView(idx, 1)); + } + + const auto interpScheme = + InterpSchemeFixtures::get("structured_linear_spherical"); + + const auto interp = + Interpolation(interpScheme, sourceFunctionSpace, targetFunctionSpace); + + interp.execute(sourceFieldSet, targetFieldSet); + targetFieldSet.haloExchange(); + + auto errorView = + createFieldView(targetFunctionSpace, "error", targetFieldSet); + + auto maxError = 0.; + for (auto idx = idx_t{0}; idx < targetFunctionSpace.size(); idx++) { + auto [uTrue, vTrue] = + vortexHorizontal(targetLonLatView(idx, 0), targetLonLatView(idx, 1)); + errorView(idx) = + std::hypot(uTrue - uTargetView(idx), vTrue - vTargetView(idx)); + maxError = std::max(maxError, errorView(idx)); + } + EXPECT_APPROX_EQ(maxError, 0., 0.00017); + + gmshOutput("vector_components_source.msh", sourceFieldSet); + gmshOutput("vector_components_target.msh", targetFieldSet); + + auto sourceAdjointFieldSet = FieldSet{}; + auto targetAdjointFieldSet = FieldSet{}; + + targetAdjointFieldSet.add(targetFieldSet["u"].clone()); + targetAdjointFieldSet.add(targetFieldSet["v"].clone()); + + targetAdjointFieldSet.adjointHaloExchange(); + + auto uSourceAdjointView = + createFieldView(sourceFunctionSpace, "u", sourceAdjointFieldSet); + auto vSourceAdjointView = + createFieldView(sourceFunctionSpace, "v", sourceAdjointFieldSet); + uSourceAdjointView.assign(0.); + vSourceAdjointView.assign(0.); + + // sourceAdjointFieldSet.set_dirty(false); + interp.execute_adjoint(sourceAdjointFieldSet, targetAdjointFieldSet); + + constexpr auto tinyNum = 1e-13; + const auto targetDotTarget = dotProduct(uTargetView, uTargetView) + + dotProduct(vTargetView, vTargetView); + const auto sourceDotSourceAdjoint = + dotProduct(uSourceView, uSourceAdjointView) + + dotProduct(vSourceView, vSourceAdjointView); + + const auto dotProdRatio = targetDotTarget / sourceDotSourceAdjoint; + EXPECT_APPROX_EQ(dotProdRatio, 1., tinyNum); +} + } // namespace test } // namespace atlas