diff --git a/src/pymatgen/analysis/pourbaix_diagram.py b/src/pymatgen/analysis/pourbaix_diagram.py index e5e9c003e79..15c9a407f20 100644 --- a/src/pymatgen/analysis/pourbaix_diagram.py +++ b/src/pymatgen/analysis/pourbaix_diagram.py @@ -31,11 +31,16 @@ from pymatgen.util.string import Stringify if TYPE_CHECKING: - from typing import Any + from collections.abc import Sequence + from typing import Any, ClassVar, Literal import matplotlib.pyplot as plt + from numpy.typing import NDArray from typing_extensions import Self + from pymatgen.core import DummySpecies, Species + from pymatgen.entries.computed_entries import ComputedStructureEntry + __author__ = "Sai Jayaraman" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "0.4" @@ -62,12 +67,15 @@ logger = logging.getLogger(__name__) -PREFAC = 0.0591 +PREFAC: float = 0.0591 # ln(10) * RT/nF in Nernst Equation # TODO: Revise to more closely reflect PDEntry, invoke from energy/composition + # TODO: PourbaixEntries depend implicitly on having entry energies be # formation energies, should be a better way to get from raw energies + + # TODO: uncorrected_energy is a bit of a misnomer, but not sure what to rename class PourbaixEntry(MSONable, Stringify): """ @@ -82,7 +90,12 @@ class PourbaixEntry(MSONable, Stringify): work. This may be changed to be more flexible in the future. """ - def __init__(self, entry, entry_id=None, concentration=1e-6): + def __init__( + self, + entry: ComputedEntry | ComputedStructureEntry, + entry_id: str | None = None, + concentration: float = 1e-6, + ) -> None: """ Args: entry (ComputedEntry | ComputedStructureEntry | PDEntry | IonEntry): An entry object @@ -100,29 +113,42 @@ def __init__(self, entry, entry_id=None, concentration=1e-6): self.charge = 0 self.uncorrected_energy = entry.energy if entry_id is not None: - self.entry_id = entry_id + self.entry_id: str | None = entry_id elif getattr(entry, "entry_id", None): self.entry_id = entry.entry_id else: self.entry_id = None + def __repr__(self) -> str: + energy, npH, nPhi, nH2O, entry_id = ( + self.energy, + self.npH, + self.nPhi, + self.nH2O, + self.entry_id, + ) + return ( + f"{type(self).__name__}({self.entry.composition} with {energy=:.4f}, {npH=}, " + f"{nPhi=}, {nH2O=}, {entry_id=})" + ) + @property - def npH(self): + def npH(self) -> float: """The number of H.""" return self.entry.composition.get("H", 0) - 2 * self.entry.composition.get("O", 0) @property - def nH2O(self): + def nH2O(self) -> float: """The number of H2O.""" return self.entry.composition.get("O", 0) @property - def nPhi(self): + def nPhi(self) -> float: """The number of electrons.""" return self.npH - self.charge @property - def name(self): + def name(self) -> str: """The entry's name.""" if self.phase_type == "Solid": return f"{self.entry.reduced_formula}(s)" @@ -130,22 +156,22 @@ def name(self): return self.entry.name @property - def energy(self): + def energy(self) -> float: """Total energy of the Pourbaix entry (at pH, V = 0 vs. SHE).""" # Note: this implicitly depends on formation energies as input return self.uncorrected_energy + self.conc_term - (MU_H2O * self.nH2O) @property - def energy_per_atom(self): + def energy_per_atom(self) -> float: """Energy per atom of the Pourbaix entry.""" return self.energy / self.composition.num_atoms @property - def elements(self): + def elements(self) -> list[Element | Species | DummySpecies]: """Elements in the entry.""" return self.entry.elements - def energy_at_conditions(self, pH, V): + def energy_at_conditions(self, pH: float, V: float) -> float: """Get free energy for a given pH and V. Args: @@ -157,7 +183,7 @@ def energy_at_conditions(self, pH, V): """ return self.energy + self.npH * PREFAC * pH + self.nPhi * V - def get_element_fraction(self, element): + def get_element_fraction(self, element: Element | str) -> float: """Get the elemental fraction of a given non-OH element. Args: @@ -170,13 +196,13 @@ def get_element_fraction(self, element): return self.composition.get(element) * self.normalization_factor @property - def normalized_energy(self): + def normalized_energy(self) -> float: """Energy normalized by number of non H or O atoms, e.g. for Zn2O6, energy / 2 or for AgTe3(OH)3, energy / 4. """ return self.energy * self.normalization_factor - def normalized_energy_at_conditions(self, pH, V): + def normalized_energy_at_conditions(self, pH: float, V: float) -> float: """Energy at an electrochemical condition, compatible with numpy arrays for pH/V input. @@ -190,7 +216,7 @@ def normalized_energy_at_conditions(self, pH, V): return self.energy_at_conditions(pH, V) * self.normalization_factor @property - def conc_term(self): + def conc_term(self) -> float: """The concentration contribution to the free energy. Should only be present when there are ions in the entry. """ @@ -202,14 +228,17 @@ def as_dict(self): Note that the pH, voltage, H2O factors are always calculated when constructing a PourbaixEntry object. """ - dct = {"@module": type(self).__module__, "@class": type(self).__name__} + dct = { + "@module": type(self).__module__, + "@class": type(self).__name__, + "entry": self.entry.as_dict(), + "concentration": self.concentration, + "entry_id": self.entry_id, + } if isinstance(self.entry, IonEntry): dct["entry_type"] = "Ion" else: dct["entry_type"] = "Solid" - dct["entry"] = self.entry.as_dict() - dct["concentration"] = self.concentration - dct["entry_id"] = self.entry_id return dct @classmethod @@ -224,17 +253,17 @@ def from_dict(cls, dct: dict) -> Self: return cls(entry, entry_id, concentration) @property - def normalization_factor(self): + def normalization_factor(self) -> float: """Sum of number of atoms minus the number of H and O in composition.""" return 1.0 / (self.num_atoms - self.composition.get("H", 0) - self.composition.get("O", 0)) @property - def composition(self): + def composition(self) -> Composition: """Composition.""" return self.entry.composition @property - def num_atoms(self): + def num_atoms(self) -> float: """Number of atoms in current formula. Useful for normalization.""" return self.composition.num_atoms @@ -245,40 +274,27 @@ def to_pretty_string(self) -> str: return self.entry.name - def __repr__(self): - energy, npH, nPhi, nH2O, entry_id = ( - self.energy, - self.npH, - self.nPhi, - self.nH2O, - self.entry_id, - ) - return ( - f"{type(self).__name__}({self.entry.composition} with {energy=:.4f}, {npH=}, " - f"{nPhi=}, {nH2O=}, {entry_id=})" - ) - class MultiEntry(PourbaixEntry): """PourbaixEntry-like object for constructing multi-elemental Pourbaix diagrams.""" - def __init__(self, entry_list, weights=None): + def __init__(self, entry_list: Sequence[PourbaixEntry], weights: list[float] | None = None) -> None: """Initialize a MultiEntry. Args: - entry_list ([PourbaixEntry]): List of component PourbaixEntries - weights ([float]): Weights associated with each entry. Default is None + entry_list (Sequence[PourbaixEntry]): Component PourbaixEntries. + weights (list[float]): Weights associated with each entry. Default is None """ self.weights = weights or [1.0] * len(entry_list) - self.entry_list = entry_list + self.entry_list = list(entry_list) - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: """ Because most of the attributes here are just weighted averages of the entry_list, we save some space by having a set of conditionals to define the attributes. """ # Attributes that are weighted averages of entry attributes - if attr in [ + if attr in { "energy", "npH", "nH2O", @@ -287,7 +303,7 @@ def __getattr__(self, attr): "composition", "uncorrected_energy", "elements", - ]: + }: # TODO: Composition could be changed for compat with sum start = Composition() if attr == "composition" else 0 weighted_values = ( @@ -296,18 +312,13 @@ def __getattr__(self, attr): return sum(weighted_values, start) # Attributes that are just lists of entry attributes - if attr in ["entry_id", "phase_type"]: + if attr in {"entry_id", "phase_type"}: return [getattr(entry, attr) for entry in self.entry_list] # normalization_factor, num_atoms should work from superclass return self.__getattribute__(attr) - @property - def name(self): - """MultiEntry name, i.e. the name of each entry joined by ' + '.""" - return " + ".join(entry.name for entry in self.entry_list) - - def __repr__(self): + def __repr__(self) -> str: energy, npH, nPhi, nH2O, entry_id = ( self.energy, self.npH, @@ -318,7 +329,12 @@ def __repr__(self): cls_name, species = type(self).__name__, self.name return f"Pourbaix{cls_name}({energy=:.4f}, {npH=}, {nPhi=}, {nH2O=}, {entry_id=}, {species=})" - def as_dict(self): + @property + def name(self) -> str: + """MultiEntry name, i.e. the name of each entry joined by ' + '.""" + return " + ".join(entry.name for entry in self.entry_list) + + def as_dict(self) -> dict[str, Any]: """Get MSONable dict.""" return { "@module": type(self).__module__, @@ -367,6 +383,9 @@ def __init__(self, ion: Ion, energy: float, name: str | None = None, attribute=N name = name or self.ion.reduced_formula super().__init__(composition=ion.composition, energy=energy, name=name, attribute=attribute) + def __repr__(self) -> str: + return f"IonEntry : {self.composition} with energy = {self.energy:.4f}" + @classmethod def from_dict(cls, dct: dict) -> Self: """Get an IonEntry object from a dict.""" @@ -377,34 +396,30 @@ def from_dict(cls, dct: dict) -> Self: dct.get("attribute"), ) - def as_dict(self): + def as_dict(self) -> dict[Literal["ion", "energy", "name"], Any]: """Create a dict of composition, energy, and ion name.""" return {"ion": self.ion.as_dict(), "energy": self.energy, "name": self.name} - def __repr__(self): - return f"IonEntry : {self.composition} with energy = {self.energy:.4f}" - -def ion_or_solid_comp_object(formula): - """Get an Ion or Composition object given a formula. +def ion_or_solid_comp_object(formula: str) -> Composition | Ion: + """Get an Ion or Composition from a formula. Args: - formula: String formula. Eg. of ion: NaOH(aq), Na[+]; - Eg. of solid: Fe2O3(s), Fe(s), Na2O + formula (str): Formula. E.g. of ion: NaOH(aq), Na[+], Na+; + E.g. of solid: Fe2O3(s), Fe(s), Na2O. Returns: - Composition/Ion object + Composition/Ion object. """ - if re.match(r"\[([^\[\]]+)\]|\(aq\)", formula): - comp_obj = Ion.from_formula(formula) - elif re.search(r"\(s\)", formula): - comp_obj = Composition(formula[:-3]) - else: - comp_obj = Composition(formula) - return comp_obj + # Formula for ion + if formula.endswith("(aq)") or re.search(r"\[.*\]$", formula) or "-" in formula or "+" in formula: + return Ion.from_formula(formula) + # Formula for solid + if formula.endswith("(s)"): + return Composition(formula[:-3]) -ELEMENTS_HO = {Element("H"), Element("O")} + return Composition(formula) # TODO: the solids filter breaks some of the functionality of the @@ -418,6 +433,8 @@ def ion_or_solid_comp_object(formula): class PourbaixDiagram(MSONable): """Create a Pourbaix diagram from entries.""" + elements_ho: ClassVar[set[Element]] = {Element("H"), Element("O")} + def __init__( self, entries: list[PourbaixEntry] | list[MultiEntry], @@ -425,7 +442,7 @@ def __init__( conc_dict: dict[str, float] | None = None, filter_solids: bool = True, nproc: int | None = None, - ): + ) -> None: """ Args: entries ([PourbaixEntry] or [MultiEntry]): Entries list @@ -451,7 +468,7 @@ def __init__( # Get non-OH elements self.pbx_elts = list( - set(itertools.chain.from_iterable([entry.composition.elements for entry in entries])) - ELEMENTS_HO + set(itertools.chain.from_iterable([entry.composition.elements for entry in entries])) - self.elements_ho ) self.dim = len(self.pbx_elts) - 1 @@ -463,7 +480,7 @@ def __init__( self._unprocessed_entries = single_entries self._filtered_entries = single_entries self._conc_dict = None - self._elt_comp = {k: v for k, v in entries[0].composition.items() if k not in ELEMENTS_HO} + self._elt_comp = {k: v for k, v in entries[0].composition.items() if k not in self.elements_ho} self._multi_element = True # Process single entry inputs @@ -483,7 +500,7 @@ def __init__( # If a conc_dict is specified, override individual entry concentrations for entry in ion_entries: - ion_elts = list(set(entry.elements) - ELEMENTS_HO) + ion_elts = list(set(entry.elements) - self.elements_ho) # TODO: the logic here for ion concentration setting is in two # places, in PourbaixEntry and here, should be consolidated if len(ion_elts) == 1: @@ -512,15 +529,15 @@ def __init__( self._stable_domains, self._stable_domain_vertices = self.get_pourbaix_domains(self._processed_entries) - def _convert_entries_to_points(self, pourbaix_entries): + def _convert_entries_to_points(self, pourbaix_entries: list[PourbaixEntry]) -> NDArray: """ Args: - pourbaix_entries ([PourbaixEntry]): list of Pourbaix entries + pourbaix_entries (list[PourbaixEntry]): Pourbaix entries to process into vectors in nph-nphi-composition space. Returns: - list of vectors, [[nph, nphi, e0, x1, x2, ..., xn-1]] - corresponding to each entry in nph-nphi-composition space + NDAarray: vectors as [[nph, nphi, e0, x1, x2, ..., xn-1]] + corresponding to each entry in nph-nphi-composition space """ vecs = [ [entry.npH, entry.nPhi, entry.energy] + [entry.composition.get(elt) for elt in self.pbx_elts[:-1]] @@ -531,15 +548,18 @@ def _convert_entries_to_points(self, pourbaix_entries): vecs *= norms return vecs - def _get_hull_in_nph_nphi_space(self, entries) -> tuple[list[PourbaixEntry], list[Simplex]]: + def _get_hull_in_nph_nphi_space( + self, + entries: list[PourbaixEntry], + ) -> tuple[list[PourbaixEntry], list[Simplex]]: """Generate convex hull of Pourbaix diagram entries in composition, npH, and nphi space. This enables filtering of multi-entries such that only compositionally stable combinations of entries are included. Args: - entries ([PourbaixEntry]): list of PourbaixEntries to construct - the convex hull + entries (list[PourbaixEntry]): PourbaixEntries to construct + the convex hull. Returns: tuple[list[PourbaixEntry], list[Simplex]]: PourbaixEntry list and stable @@ -588,11 +608,15 @@ def _get_hull_in_nph_nphi_space(self, entries) -> tuple[list[PourbaixEntry], lis return min_entries, valid_facets - def _preprocess_pourbaix_entries(self, entries, nproc=None): + def _preprocess_pourbaix_entries( + self, + entries: list[PourbaixEntry], + nproc: int | None = None, + ) -> list[MultiEntry]: """Generate multi-entries for Pourbaix diagram. Args: - entries ([PourbaixEntry]): list of PourbaixEntries to preprocess + entries (list[PourbaixEntry]): PourbaixEntries to preprocess into MultiEntries nproc (int): number of processes to be used in parallel treatment of entry combos @@ -605,23 +629,23 @@ def _preprocess_pourbaix_entries(self, entries, nproc=None): min_entries, valid_facets = self._get_hull_in_nph_nphi_space(entries) - combos = [] + combos: list[list[frozenset]] = [] for facet in valid_facets: for idx in range(1, self.dim + 2): - these_combos = [] + these_combos: list[frozenset] = [] for combo in itertools.combinations(facet, idx): these_entries = [min_entries[i] for i in combo] these_combos.append(frozenset(these_entries)) combos.append(these_combos) - all_combos = set(itertools.chain.from_iterable(combos)) + all_combos: set | list = set(itertools.chain.from_iterable(combos)) - list_combos = [] - for idx in all_combos: - list_combos.append(list(idx)) + list_combos: list = [] + for combo in all_combos: + list_combos.append(list(combo)) all_combos = list_combos - multi_entries = [] + multi_entries: list = [] # Parallel processing of multi-entry generation if nproc is not None: @@ -637,7 +661,11 @@ def _preprocess_pourbaix_entries(self, entries, nproc=None): return multi_entries - def _generate_multielement_entries(self, entries, nproc=None): + def _generate_multielement_entries( + self, + entries: list[PourbaixEntry], + nproc: int | None = None, + ) -> list[MultiEntry]: """ Create entries for multi-element Pourbaix construction. @@ -655,13 +683,13 @@ def _generate_multielement_entries(self, entries, nproc=None): total_comp = Composition(self._elt_comp) # generate all combinations of compounds that have all elements - entry_combos = [itertools.combinations(entries, idx + 1) for idx in range(n_elems)] - entry_combos = itertools.chain.from_iterable(entry_combos) + entry_combos: list = [itertools.combinations(entries, idx + 1) for idx in range(n_elems)] + entry_combos = list(itertools.chain.from_iterable(entry_combos)) - entry_combos = filter(lambda x: total_comp < MultiEntry(x).composition, entry_combos) + entry_combos = list(filter(lambda x: total_comp < MultiEntry(x).composition, entry_combos)) # Generate and filter entries - processed_entries = [] + processed_entries: list = [] total = sum(comb(len(entries), idx + 1) for idx in range(n_elems)) if total > 1e6: warnings.warn( @@ -676,15 +704,17 @@ def _generate_multielement_entries(self, entries, nproc=None): processed_entries = list(filter(bool, processed_entries)) # Serial processing of multi-entry generation else: - for entry_combo in entry_combos: - processed_entry = self.process_multientry(entry_combo, total_comp) + for combo in entry_combos: + processed_entry = self.process_multientry(combo, total_comp) if processed_entry is not None: processed_entries.append(processed_entry) return processed_entries @staticmethod - def process_multientry(entry_list, prod_comp, coeff_threshold=1e-4): + def process_multientry( + entry_list: Sequence, prod_comp: Composition, coeff_threshold: float = 1e-4 + ) -> MultiEntry | None: """Static method for finding a multientry based on a list of entries and a product composition. Essentially checks to see if a valid aqueous @@ -693,7 +723,7 @@ def process_multientry(entry_list, prod_comp, coeff_threshold=1e-4): with weights according to the coefficients if so. Args: - entry_list ([Entry]): list of entries from which to + entry_list (Sequence[Entry]): Entries from which to create a MultiEntry prod_comp (Composition): composition constraint for setting weights of MultiEntry @@ -722,7 +752,10 @@ def process_multientry(entry_list, prod_comp, coeff_threshold=1e-4): return None @staticmethod - def get_pourbaix_domains(pourbaix_entries, limits=None): + def get_pourbaix_domains( + pourbaix_entries: list[PourbaixEntry], + limits: list[list[float]] | None = None, + ) -> tuple[dict, dict]: """Get a set of Pourbaix stable domains (i.e. polygons) in pH-V space from a list of pourbaix_entries. @@ -736,25 +769,27 @@ def get_pourbaix_domains(pourbaix_entries, limits=None): points. Args: - pourbaix_entries ([PourbaixEntry]): Pourbaix entries + pourbaix_entries (list[PourbaixEntry]): Pourbaix entries with which to construct stable Pourbaix domains - limits ([[float]]): limits in which to do the pourbaix + limits (list[list[float]]): limits in which to do the pourbaix analysis Returns: - Returns a dict of the form {entry: [boundary_points]}. - The list of boundary points are the sides of the N-1 - dim polytope bounding the allowable ph-V range of each entry. + tuple[dict[PourbaixEntry, list], dict[PourbaixEntry, NDArray]: + The first dict is of form: {entry: [boundary_points]}. + The list of boundary points are the sides of the N-1 + dim polytope bounding the allowable ph-V range of each entry. """ if limits is None: limits = [[-2, 16], [-4, 4]] # Get hyperplanes - hyperplanes = [ - np.array([-PREFAC * entry.npH, -entry.nPhi, 0, -entry.energy]) * entry.normalization_factor - for entry in pourbaix_entries - ] - hyperplanes = np.array(hyperplanes) + hyperplanes = np.array( + [ + np.array([-PREFAC * entry.npH, -entry.nPhi, 0, -entry.energy]) * entry.normalization_factor + for entry in pourbaix_entries + ] + ) hyperplanes[:, 2] = 1 max_contribs = np.max(np.abs(hyperplanes), axis=0) @@ -773,7 +808,7 @@ def get_pourbaix_domains(pourbaix_entries, limits=None): hs_int = HalfspaceIntersection(hs_hyperplanes, np.array(interior_point)) # organize the boundary points by entry - pourbaix_domains = {entry: [] for entry in pourbaix_entries} + pourbaix_domains: dict[PourbaixEntry, list] = {entry: [] for entry in pourbaix_entries} for intersection, facet in zip(hs_int.intersections, hs_int.dual_facets, strict=True): for v in facet: if v < len(pourbaix_entries): @@ -782,18 +817,21 @@ def get_pourbaix_domains(pourbaix_entries, limits=None): # Remove entries with no Pourbaix region pourbaix_domains = {k: v for k, v in pourbaix_domains.items() if v} - pourbaix_domain_vertices = {} + pourbaix_domain_vertices: dict[PourbaixEntry, NDArray[float]] = {} for entry, points in pourbaix_domains.items(): points = np.array(points)[:, :2] # Initial sort to ensure consistency points = points[np.lexsort(np.transpose(points))] - center = np.mean(points, axis=0) - points_centered = points - center + center: NDArray[float] = np.mean(points, axis=0) + points_centered: NDArray[float] = points - center # Sort points by cross product of centered points, # isn't strictly necessary but useful for plotting tools - points_centered = sorted(points_centered, key=cmp_to_key(lambda x, y: x[0] * y[1] - x[1] * y[0])) + points_centered = sorted( + points_centered, + key=cmp_to_key(lambda x, y: x[0] * y[1] - x[1] * y[0]), # type: ignore[index] + ) points = points_centered + center # Create simplices corresponding to Pourbaix boundary @@ -803,7 +841,7 @@ def get_pourbaix_domains(pourbaix_entries, limits=None): return pourbaix_domains, pourbaix_domain_vertices - def find_stable_entry(self, pH, V): + def find_stable_entry(self, pH: float, V: float) -> PourbaixEntry: """Find stable entry at a pH,V condition. Args: @@ -816,15 +854,20 @@ def find_stable_entry(self, pH, V): energies_at_conditions = [entry.normalized_energy_at_conditions(pH, V) for entry in self.stable_entries] return self.stable_entries[np.argmin(energies_at_conditions)] - def get_decomposition_energy(self, entry, pH, V): + def get_decomposition_energy( + self, + entry: PourbaixEntry, + pH: float, + V: float, + ) -> NDArray: """Find decomposition to most stable entries in eV/atom, supports vectorized inputs for pH and V. Args: entry (PourbaixEntry): PourbaixEntry corresponding to compound to find the decomposition for - pH (float, list[float]): pH at which to find the decomposition - V (float, list[float]): voltage at which to find the decomposition + pH (float): pH at which to find the decomposition + V (float): voltage at which to find the decomposition Returns: Decomposition energy for the entry, i.e. the energy above @@ -833,7 +876,7 @@ def get_decomposition_energy(self, entry, pH, V): # Check composition consistency between entry and Pourbaix diagram: pbx_comp = Composition(self._elt_comp).fractional_composition entry_pbx_comp = Composition( - {elt: coeff for elt, coeff in entry.composition.items() if elt not in ELEMENTS_HO} + {elt: coeff for elt, coeff in entry.composition.items() if elt not in self.elements_ho} ).fractional_composition if entry_pbx_comp != pbx_comp: raise ValueError("Composition of stability entry does not match Pourbaix Diagram") @@ -846,7 +889,7 @@ def get_decomposition_energy(self, entry, pH, V): decomposition_energy /= entry.composition.num_atoms return decomposition_energy - def get_hull_energy(self, pH: float | list[float], V: float | list[float]) -> np.ndarray: + def get_hull_energy(self, pH: float | list[float], V: float | list[float]) -> NDArray: """Get the minimum energy of the Pourbaix "basin" that is formed from the stable Pourbaix planes. Vectorized. @@ -860,7 +903,7 @@ def get_hull_energy(self, pH: float | list[float], V: float | list[float]) -> np all_gs = np.array([entry.normalized_energy_at_conditions(pH, V) for entry in self.stable_entries]) return np.min(all_gs, axis=0) - def get_stable_entry(self, pH, V): + def get_stable_entry(self, pH: float, V: float) -> PourbaixEntry | MultiEntry: """Get the stable entry at a given pH, V condition. Args: @@ -875,26 +918,26 @@ def get_stable_entry(self, pH, V): return self.stable_entries[np.argmin(all_gs)] @property - def stable_entries(self): + def stable_entries(self) -> list: """The stable entries in the Pourbaix diagram.""" return list(self._stable_domains) @property - def unstable_entries(self): + def unstable_entries(self) -> list: """All unstable entries in the Pourbaix diagram.""" return [entry for entry in self.all_entries if entry not in self.stable_entries] @property - def all_entries(self): + def all_entries(self) -> list: """All entries used to generate the Pourbaix diagram.""" return self._processed_entries @property - def unprocessed_entries(self): + def unprocessed_entries(self) -> list: """Unprocessed entries.""" return self._unprocessed_entries - def as_dict(self): + def as_dict(self) -> dict[str, Any]: """Get MSONable dict.""" return { "@module": type(self).__module__, @@ -926,14 +969,33 @@ def from_dict(cls, dct: dict) -> Self: class PourbaixPlotter: """A plotter class for phase diagrams.""" - def __init__(self, pourbaix_diagram): + def __init__(self, pourbaix_diagram: PourbaixDiagram) -> None: """ Args: pourbaix_diagram (PourbaixDiagram): A PourbaixDiagram object. """ self._pbx = pourbaix_diagram - def show(self, *args, **kwargs): + @staticmethod + def _generate_entry_label(entry: PourbaixEntry | MultiEntry) -> str: + """ + Generates a label for the Pourbaix plotter. + + Args: + entry (PourbaixEntry or MultiEntry): entry to get a label for + """ + if isinstance(entry, MultiEntry): + return " + ".join(entry.name for entry in entry.entry_list) + + # TODO - a more elegant solution could be added later to Stringify + # for example, the pattern re.sub(r"([-+][\d\.]*)", r"$^{\1}$", ) + # will convert B(OH)4- to B(OH)$_4^-$. + # for this to work, the ion's charge always must be written AFTER + # the sign (e.g., Fe+2 not Fe2+) + string = entry.to_latex_string() + return re.sub(r"()\[([^)]*)\]", r"\1$^{\2}$", string) + + def show(self, *args, **kwargs) -> None: """Show the Pourbaix plot. Args: @@ -1003,7 +1065,7 @@ def get_pourbaix_plot( if label_domains: ax.annotate( - generate_entry_label(entry), + self._generate_entry_label(entry), center, ha="center", va="center", @@ -1032,12 +1094,12 @@ def plot_entry_stability( Args: entry (Any): The entry to plot stability for. - pH_range (tuple[float, float], optional): pH range for the plot. Defaults to (-2, 16). - pH_resolution (int, optional): pH resolution. Defaults to 100. - V_range (tuple[float, float], optional): Voltage range for the plot. Defaults to (-3, 3). - V_resolution (int, optional): Voltage resolution. Defaults to 100. - e_hull_max (float, optional): Maximum energy above the hull. Defaults to 1. - cmap (str, optional): Colormap for the plot. Defaults to "RdYlBu_r". + pH_range (tuple[float, float]): pH range for the plot. Defaults to (-2, 16). + pH_resolution (int): pH resolution. Defaults to 100. + V_range (tuple[float, float]): Voltage range for the plot. Defaults to (-3, 3). + V_resolution (int): Voltage resolution. Defaults to 100. + e_hull_max (float): Maximum energy above the hull. Defaults to 1. + cmap (str): Colormap for the plot. Defaults to "RdYlBu_r". ax (Axes, optional): Existing matplotlib Axes object for plotting. Defaults to None. **kwargs (Any): Additional keyword arguments passed to `get_pourbaix_plot`. @@ -1056,7 +1118,7 @@ def plot_entry_stability( # Plot stability map cax = ax.pcolor(pH, V, stability, cmap=cmap, vmin=0, vmax=e_hull_max) cbar = ax.figure.colorbar(cax) - cbar.set_label(f"Stability of {generate_entry_label(entry)} (eV/atom)") + cbar.set_label(f"Stability of {self._generate_entry_label(entry)} (eV/atom)") # Set ticklabels # ticklabels = [t.get_text() for t in cbar.ax.get_yticklabels()] @@ -1065,7 +1127,7 @@ def plot_entry_stability( return ax - def domain_vertices(self, entry): + def domain_vertices(self, entry) -> list: """Get the vertices of the Pourbaix domain. Args: @@ -1075,22 +1137,3 @@ def domain_vertices(self, entry): list of vertices """ return self._pbx._stable_domain_vertices[entry] - - -def generate_entry_label(entry): - """ - Generates a label for the Pourbaix plotter. - - Args: - entry (PourbaixEntry or MultiEntry): entry to get a label for - """ - if isinstance(entry, MultiEntry): - return " + ".join(entry.name for entry in entry.entry_list) - - # TODO - a more elegant solution could be added later to Stringify - # for example, the pattern re.sub(r"([-+][\d\.]*)", r"$^{\1}$", ) - # will convert B(OH)4- to B(OH)$_4^-$. - # for this to work, the ion's charge always must be written AFTER - # the sign (e.g., Fe+2 not Fe2+) - string = entry.to_latex_string() - return re.sub(r"()\[([^)]*)\]", r"\1$^{\2}$", string) diff --git a/src/pymatgen/io/ase.py b/src/pymatgen/io/ase.py index f3cf8516d6d..308128c798b 100644 --- a/src/pymatgen/io/ase.py +++ b/src/pymatgen/io/ase.py @@ -6,7 +6,6 @@ from __future__ import annotations import warnings -from collections.abc import Iterable from importlib.metadata import PackageNotFoundError from typing import TYPE_CHECKING @@ -18,7 +17,7 @@ try: from ase.atoms import Atoms from ase.calculators.singlepoint import SinglePointDFTCalculator - from ase.constraints import FixAtoms + from ase.constraints import FixAtoms, FixCartesian from ase.io.jsonio import decode, encode from ase.spacegroup import Spacegroup @@ -26,7 +25,7 @@ except ImportError: NO_ASE_ERR = PackageNotFoundError("AseAtomsAdaptor requires the ASE package. Use `pip install ase`") - encode = decode = FixAtoms = SinglePointDFTCalculator = Spacegroup = None + encode = decode = FixAtoms = FixCartesian = SinglePointDFTCalculator = Spacegroup = None class Atoms: # type: ignore[no-redef] def __init__(self, *args, **kwargs): @@ -157,31 +156,38 @@ def get_atoms(structure: SiteCollection, msonable: bool = True, **kwargs) -> MSO # Get the oxidation states from the structure oxi_states: list[float | None] = [getattr(site.specie, "oxi_state", None) for site in structure] - # Read in selective dynamics if present. Note that the ASE FixAtoms class fixes (x,y,z), so - # here we make sure that [False, False, False] or [True, True, True] is set for the site selective - # dynamics property. As a consequence, if only a subset of dimensions are fixed, this won't get passed to ASE. + # Read in selective dynamics if present. + # Note that FixCartesian class uses an opposite notion of + # "fix" and "not fix" flags: in ASE True means fixed and False + # means not fixed. + fix_atoms: dict | None = None if "selective_dynamics" in structure.site_properties: - fix_atoms = [] + fix_atoms = { + str([xc, yc, zc]): ([xc, yc, zc], []) + for xc in [True, False] + for yc in [True, False] + for zc in [True, False] + } + # [False, False, False] is free to move - no constraint in ASE. + del fix_atoms[str([False, False, False])] for site in structure: selective_dynamics: ArrayLike = site.properties.get("selective_dynamics") # type: ignore[assignment] - if ( - isinstance(selective_dynamics, Iterable) - and True in selective_dynamics - and False in selective_dynamics - ): - raise ValueError( - "ASE FixAtoms constraint does not support selective dynamics in only some dimensions. " - f"Remove the {selective_dynamics=} and try again if you do not need them." - ) - is_fixed = bool(~np.all(site.properties["selective_dynamics"])) - fix_atoms.append(is_fixed) - + for cmask_str in fix_atoms: + cmask_site = (~np.array(selective_dynamics)).tolist() + fix_atoms[cmask_str][1].append(cmask_str == str(cmask_site)) else: fix_atoms = None - # Set the selective dynamics with the FixAtoms class. + # Set the selective dynamics with the FixCartesian class. if fix_atoms is not None: - atoms.set_constraint(FixAtoms(mask=fix_atoms)) + atoms.set_constraint( + [ + FixAtoms(indices) if cmask == [True, True, True] else FixCartesian(indices, mask=cmask) + for cmask, indices in fix_atoms.values() + # Do not add empty constraints + if any(indices) + ] + ) # Add any remaining site properties to the ASE Atoms object for prop in structure.site_properties: @@ -254,21 +260,40 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs) oxi_states = atoms.get_array("oxi_states") if atoms.has("oxi_states") else None # If the ASE Atoms object has constraints, make sure that they are of the - # kind FixAtoms, which are the only ones that can be supported in Pymatgen. + # FixAtoms or FixCartesian kind, which are the only ones that + # can be supported in Pymatgen. # By default, FixAtoms fixes all three (x, y, z) dimensions. if atoms.constraints: unsupported_constraint_type = False - constraint_indices = [] + constraint_indices: dict = { + str([xc, yc, zc]): ([xc, yc, zc], []) + for xc in [True, False] + for yc in [True, False] + for zc in [True, False] + } for constraint in atoms.constraints: if isinstance(constraint, FixAtoms): - constraint_indices.extend(constraint.get_indices().tolist()) + constraint_indices[str([False] * 3)][1].extend(constraint.get_indices().tolist()) + elif isinstance(constraint, FixCartesian): + cmask = (~np.array(constraint.mask)).tolist() + constraint_indices[str(cmask)][1].extend(constraint.get_indices().tolist()) else: unsupported_constraint_type = True if unsupported_constraint_type: warnings.warn( - "Only FixAtoms is supported by Pymatgen. Other constraints will not be set.", stacklevel=2 + "Only FixAtoms and FixCartesian is supported by Pymatgen. Other constraints will not be set.", + stacklevel=2, ) - sel_dyn = [[False] * 3 if atom.index in constraint_indices else [True] * 3 for atom in atoms] + sel_dyn = [] + for atom in atoms: + constrained = False + for mask, indices in constraint_indices.values(): + if atom.index in indices: + sel_dyn.append(mask) + constrained = True + break # Assume no duplicates + if not constrained: + sel_dyn.append([False] * 3) else: sel_dyn = None diff --git a/tests/analysis/test_pourbaix_diagram.py b/tests/analysis/test_pourbaix_diagram.py index 22066499257..30150f32aa6 100644 --- a/tests/analysis/test_pourbaix_diagram.py +++ b/tests/analysis/test_pourbaix_diagram.py @@ -8,7 +8,14 @@ from monty.serialization import dumpfn, loadfn from pytest import approx -from pymatgen.analysis.pourbaix_diagram import IonEntry, MultiEntry, PourbaixDiagram, PourbaixEntry, PourbaixPlotter +from pymatgen.analysis.pourbaix_diagram import ( + IonEntry, + MultiEntry, + PourbaixDiagram, + PourbaixEntry, + PourbaixPlotter, + ion_or_solid_comp_object, +) from pymatgen.core.composition import Composition from pymatgen.core.ion import Ion from pymatgen.entries.computed_entries import ComputedEntry @@ -315,3 +322,34 @@ def test_plot_entry_stability(self): binary_plotter = PourbaixPlotter(pd_binary) ax = binary_plotter.plot_entry_stability(self.test_data["Ag-Te"][53]) assert isinstance(ax, plt.Axes) + + +class TestIonOrSolidCompObject: + def test_ion(self): + # Test cations + assert ion_or_solid_comp_object("Li+").charge == 1 + assert ion_or_solid_comp_object("Li[+]").charge == 1 + assert ion_or_solid_comp_object("Ca[2+]").charge == 2 + assert ion_or_solid_comp_object("Ca[+2]").charge == 2 + assert ion_or_solid_comp_object("Ca++").charge == 2 + assert ion_or_solid_comp_object("Ca[++]").charge == 2 + assert ion_or_solid_comp_object("Ca2+").charge == 1 + assert ion_or_solid_comp_object("C2O4-2").charge == -2 + + # Test anions + assert ion_or_solid_comp_object("Cl-").charge == -1 + assert ion_or_solid_comp_object("Cl[-]").charge == -1 + assert ion_or_solid_comp_object("SO4[-2]").charge == -2 + assert ion_or_solid_comp_object("SO4-2").charge == -2 + assert ion_or_solid_comp_object("SO42-").charge == -1 + assert ion_or_solid_comp_object("SO4--").charge == -2 + assert ion_or_solid_comp_object("SO4[--]").charge == -2 + assert ion_or_solid_comp_object("N3-").charge == -1 + + def test_solid(self): + # Test end with "(s)" + assert ion_or_solid_comp_object("Fe2O3(s)") == Composition("Fe2O3") + assert ion_or_solid_comp_object("Fe(s)") == Composition("Fe1") + + # Test end without "(s)" + assert type(ion_or_solid_comp_object("Na2O")) is Composition diff --git a/tests/io/test_ase.py b/tests/io/test_ase.py index 4459b4ea5f0..4328d0f6959 100644 --- a/tests/io/test_ase.py +++ b/tests/io/test_ase.py @@ -91,6 +91,18 @@ def test_get_atoms_from_structure_dyn(): STRUCTURE.add_site_property("selective_dynamics", [[False] * 3] * len(STRUCTURE)) atoms = AseAtomsAdaptor.get_atoms(STRUCTURE) assert atoms.constraints[0].get_indices().tolist() == [atom.index for atom in atoms] + STRUCTURE.add_site_property("selective_dynamics", [[True] * 3] * len(STRUCTURE)) + atoms = AseAtomsAdaptor.get_atoms(STRUCTURE) + assert len(atoms.constraints) == 0 + rng = np.random.default_rng(seed=1234) + sel_dyn = [[rng.random() < 0.5, rng.random() < 0.5, rng.random() < 0.5] for _ in STRUCTURE] + STRUCTURE.add_site_property("selective_dynamics", sel_dyn) + atoms = AseAtomsAdaptor.get_atoms(STRUCTURE) + for c in atoms.constraints: + # print(c) + assert isinstance(c, ase.constraints.FixAtoms | ase.constraints.FixCartesian) + ase_mask = c.mask if isinstance(c, ase.constraints.FixCartesian) else [True, True, True] + assert len(c.index) == len([mask for mask in sel_dyn if np.array_equal(mask, ~np.array(ase_mask))]) def test_get_atoms_from_molecule(): @@ -178,8 +190,10 @@ def test_get_structure_mag(): "select_dyn", [ [True, True, True], + [True, False, True], [False, False, False], np.array([True, True, True]), + np.array([True, False, True]), np.array([False, False, False]), ], )