From 852d7d39b97ef227d2fbca85fd1f38e8355bd65e Mon Sep 17 00:00:00 2001
From: t-reents <timo.reents1@gmail.com>
Date: Wed, 3 Jul 2024 17:31:08 +0200
Subject: [PATCH] Simplify `get_builder_from_protocol` in
 `ProjwfcBandsWorkChain`

This commit mainly simplifies the current version of the
`get_builder_from_protocol` method in `ProjwfcBandsWorkChain`.

Moreover, it adds support for overrides containing standard Python datatypes,
e.g. `kpoints_distance` specified as a float`.
---
 .../workflows/projwfcbands.py                 | 24 ++++---------------
 1 file changed, 4 insertions(+), 20 deletions(-)

diff --git a/src/aiida_wannier90_workflows/workflows/projwfcbands.py b/src/aiida_wannier90_workflows/workflows/projwfcbands.py
index 3117321..8d50080 100644
--- a/src/aiida_wannier90_workflows/workflows/projwfcbands.py
+++ b/src/aiida_wannier90_workflows/workflows/projwfcbands.py
@@ -107,7 +107,6 @@ def get_builder_from_protocol(  # pylint: disable=arguments-differ
         """
         from aiida_wannier90_workflows.utils.workflows.builder.submit import (
             recursive_merge_builder,
-            recursive_merge_container,
         )
 
         type_check(pw_code, (str, int, orm.Code))
@@ -116,13 +115,9 @@ def get_builder_from_protocol(  # pylint: disable=arguments-differ
         type_check(protocol, str, allow_none=True)
         type_check(overrides, dict, allow_none=True)
 
-        # Prepare workchain builder
+        # # Prepare workchain builder
         builder = cls.get_builder()
 
-        protocol_inputs = cls.get_protocol_inputs(
-            protocol=protocol, overrides=overrides
-        )
-
         projwfc_overrides = None
         if overrides:
             projwfc_overrides = overrides.pop("projwfc", None)
@@ -137,25 +132,14 @@ def get_builder_from_protocol(  # pylint: disable=arguments-differ
 
         # By default do not run relax
         pwbands_builder.pop("relax", None)
-        inputs = pwbands_builder._inputs(prune=True)  # pylint: disable=protected-access
 
         projwfc_builder = ProjwfcBaseWorkChain.get_builder_from_protocol(
             projwfc_code, protocol=protocol, overrides=projwfc_overrides
         )
+        projwfc_builder.pop("clean_workdir", None)
 
-        inputs["projwfc"] = projwfc_builder._inputs(  # pylint: disable=protected-access
-            prune=True
-        )
-        inputs["projwfc"].pop("clean_workdir", None)
-
-        # Need to convert `clean_workdir` to `orm.Bool`
-        if "clean_workdir" in protocol_inputs:
-            protocol_inputs["clean_workdir"] = orm.Bool(
-                protocol_inputs["clean_workdir"]
-            )
-
-        inputs = recursive_merge_container(inputs, protocol_inputs)
-        builder = recursive_merge_builder(builder, inputs)
+        builder.projwfc = projwfc_builder
+        builder = recursive_merge_builder(builder, pwbands_builder)
 
         return builder