diff --git a/src/pymatgen/core/sites.py b/src/pymatgen/core/sites.py index 520a84e6c15..925ea8e7515 100644 --- a/src/pymatgen/core/sites.py +++ b/src/pymatgen/core/sites.py @@ -53,11 +53,11 @@ def __init__( iii.Dict of elements/species and occupancies, e.g. {"Fe" : 0.5, "Mn":0.5}. This allows the setup of disordered structures. - coords: Cartesian coordinates of site. - properties: Properties associated with the site as a dict, e.g. + coords (ArrayLike): Cartesian coordinates of site. + properties (dict): Properties associated with the site, e.g. {"magmom": 5}. Defaults to None. - label: Label for the site. Defaults to None. - skip_checks: Whether to ignore all the usual checks and just + label (str): Label for the site. Defaults to None. + skip_checks (bool): Whether to ignore all the usual checks and just create the site. Use this if the Site is created in a controlled manner and speed is desired. """ @@ -310,20 +310,20 @@ def __init__( symbols, e.g. "Li", "Fe2+", "P" or atomic numbers, e.g. 3, 56, or actual Element or Species objects. iii.Dict of elements/species and occupancies, e.g. - {"Fe" : 0.5, "Mn":0.5}. This allows the setup of + {"Fe": 0.5, "Mn": 0.5}. This allows the setup of disordered structures. - coords: Coordinates of site, fractional coordinates + coords (ArrayLike): Coordinates of site, fractional coordinates by default. See ``coords_are_cartesian`` for more details. - lattice: Lattice associated with the site. - to_unit_cell: Translates fractional coordinate to the + lattice (Lattice): Lattice associated with the site. + to_unit_cell (bool): Translates fractional coordinate to the basic unit cell, i.e. all fractional coordinates satisfy 0 <= a < 1. Defaults to False. - coords_are_cartesian: Set to True if you are providing + coords_are_cartesian (bool): Set to True if you are providing Cartesian coordinates. Defaults to False. - properties: Properties associated with the site as a dict, e.g. + properties (dict): Properties associated with the site, e.g. {"magmom": 5}. Defaults to None. - label: Label for the site. Defaults to None. - skip_checks: Whether to ignore all the usual checks and just + label (str): Label for the site. Defaults to None. + skip_checks (bool): Whether to ignore all the usual checks and just create the site. Use this if the PeriodicSite is created in a controlled manner and speed is desired. """ diff --git a/src/pymatgen/core/structure.py b/src/pymatgen/core/structure.py index 6d0cd6545ae..9cde191afa7 100644 --- a/src/pymatgen/core/structure.py +++ b/src/pymatgen/core/structure.py @@ -46,7 +46,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator, Sequence - from typing import Any, SupportsIndex + from typing import Any, ClassVar, SupportsIndex, TypeAlias import pandas as pd from ase import Atoms @@ -59,7 +59,7 @@ from pymatgen.util.typing import CompositionLike, MillerIndex, PathLike, PbcLike, SpeciesLike -FileFormats = Literal[ +FileFormats: TypeAlias = Literal[ "cif", "poscar", "cssr", @@ -73,7 +73,7 @@ "aims", "", ] -StructureSources = Literal["Materials Project", "COD"] +StructureSources: TypeAlias = Literal["Materials Project", "COD"] class Neighbor(Site): @@ -214,8 +214,11 @@ class SiteCollection(collections.abc.Sequence, ABC): """ # Tolerance in Angstrom for determining if sites are too close - DISTANCE_TOLERANCE = 0.5 - _properties: dict + DISTANCE_TOLERANCE: ClassVar[float] = 0.5 + + def __init__(self) -> None: + """Init a SiteCollection.""" + self._properties: dict def __contains__(self, site: object) -> bool: return site in self.sites @@ -4717,44 +4720,63 @@ def scale_lattice(self, volume: float) -> Self: return self - def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average"] = "sum") -> Self: - """Merges sites (adding occupancies) within tol of each other. - Removes site properties. + def merge_sites( + self, + tol: float = 0.01, + mode: Literal["sum", "delete", "average"] = "sum", + ) -> Self: + """Merges sites (by adding occupancies) within tolerance and removes + site properties in "sum/delete" modes. Args: tol (float): Tolerance for distance to merge sites. - mode ("sum" | "delete" | "average"): "delete" means duplicate sites are - deleted. "sum" means the occupancies are summed for the sites. - "average" means that the site is deleted but the properties are averaged - Only first letter is considered. + mode ("sum" | "delete" | "average"): Only first letter is considered at this moment. + - "delete": delete duplicate sites. + - "sum": sum the occupancies for the sites. + - "average": delete the site but average the properties if it's numerical. Returns: - Structure: self with merged sites. + Structure: Structure with merged sites. """ - dist_mat = self.distance_matrix + # TODO: change the code the allow full name after 2025-12-01 + # TODO2: add a test for mode value, currently it only checks if first letter is "s/a" + if mode.lower() not in {"sum", "delete", "average"} and mode.lower()[0] in {"s", "d", "a"}: + warnings.warn( + "mode would only allow full name sum/delete/average after 2025-12-01", DeprecationWarning, stacklevel=2 + ) + + if mode.lower()[0] not in {"s", "d", "a"}: + raise ValueError(f"Illegal {mode=}, should start with a/d/s.") + + dist_mat: NDArray = self.distance_matrix np.fill_diagonal(dist_mat, 0) clusters = fcluster(linkage(squareform((dist_mat + dist_mat.T) / 2)), tol, "distance") - sites = [] + + sites: list[PeriodicSite] = [] for cluster in np.unique(clusters): - inds = np.where(clusters == cluster)[0] - species = self[inds[0]].species - coords = self[inds[0]].frac_coords - props = self[inds[0]].properties - for n, i in enumerate(inds[1:]): - sp = self[i].species + indexes = np.where(clusters == cluster)[0] + species: Composition = self[indexes[0]].species + coords: NDArray = self[indexes[0]].frac_coords + props: dict = self[indexes[0]].properties + + for site_idx, clust_idx in enumerate(indexes[1:]): + # Sum occupancies in "sum" mode if mode.lower()[0] == "s": - species += sp - offset = self[i].frac_coords - coords - coords += ((offset - np.round(offset)) / (n + 2)).astype(coords.dtype) - for key in props: - if props[key] is not None and self[i].properties[key] != props[key]: - if mode.lower()[0] == "a" and isinstance(props[key], float): + species += self[clust_idx].species + + offset = self[clust_idx].frac_coords - coords + coords += ((offset - np.round(offset)) / (site_idx + 2)).astype(coords.dtype) + for key, val in props.items(): + if val is not None and not np.array_equal(self[clust_idx].properties[key], val): + if mode.lower()[0] == "a" and isinstance(val, float | int): # update a running total - props[key] = props[key] * (n + 1) / (n + 2) + self[i].properties[key] / (n + 2) + props[key] = val * (site_idx + 1) / (site_idx + 2) + self[clust_idx].properties[key] / ( + site_idx + 2 + ) else: props[key] = None warnings.warn( - f"Sites with different site property {key} are merged. So property is set to none", + f"Sites with different site property {key} are merged. But property is set to None", stacklevel=2, ) sites.append(PeriodicSite(species, coords, self.lattice, properties=props)) diff --git a/tests/core/test_structure.py b/tests/core/test_structure.py index 510d25846a5..91904667085 100644 --- a/tests/core/test_structure.py +++ b/tests/core/test_structure.py @@ -1,11 +1,11 @@ from __future__ import annotations import json +import math import os from fractions import Fraction from pathlib import Path from shutil import which -from unittest import skipIf import numpy as np import pytest @@ -29,6 +29,7 @@ from pymatgen.electronic_structure.core import Magmom from pymatgen.io.ase import AseAtomsAdaptor from pymatgen.io.cif import CifParser +from pymatgen.io.vasp.inputs import Poscar from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, PymatgenTest @@ -40,11 +41,11 @@ ase = Atoms = Calculator = EMT = None -enum_cmd = which("enum.x") or which("multienum.x") -mcsqs_cmd = which("mcsqs") +ENUM_CMD = which("enum.x") or which("multienum.x") +MCSQS_CMD = which("mcsqs") -class TestNeighbor(PymatgenTest): +class TestNeighbor: def test_msonable(self): struct = PymatgenTest.get_structure("Li2O") nn = struct.get_neighbors(struct[0], r=3) @@ -102,7 +103,7 @@ def setUp(self): ) self.V2O3 = IStructure.from_file(f"{TEST_FILES_DIR}/cif/V2O3.cif") - @skipIf(not (mcsqs_cmd and enum_cmd), reason="enumlib or mcsqs executable not present") + @pytest.mark.skipif(not (MCSQS_CMD and ENUM_CMD), reason="enumlib or mcsqs executable not present") def test_get_orderings(self): ordered = Structure.from_spacegroup("Im-3m", Lattice.cubic(3), ["Fe"], [[0, 0, 0]]) assert ordered.get_orderings()[0] == ordered @@ -1633,12 +1634,16 @@ def test_merge_sites(self): [0.5, 0.5, 1.501], ] struct = Structure(Lattice.cubic(1), species, coords) - struct.merge_sites(mode="s") + struct.merge_sites(mode="sum") assert struct[0].specie.symbol == "Ag" assert struct[1].species == Composition({"Cl": 0.35, "F": 0.25}) assert_allclose(struct[1].frac_coords, [0.5, 0.5, 0.5005]) - # Test for TaS2 with spacegroup 166 in 160 setting. + # Test illegal mode + with pytest.raises(ValueError, match="Illegal mode='illegal', should start with a/d/s"): + struct.merge_sites(mode="illegal") + + # Test for TaS2 with spacegroup 166 in 160 setting lattice = Lattice.hexagonal(3.374351, 20.308941) species = ["Ta", "S", "S"] coords = [ @@ -1646,10 +1651,10 @@ def test_merge_sites(self): [0.333333, 0.666667, 0.353424], [0.666667, 0.333333, 0.535243], ] - tas2 = Structure.from_spacegroup(160, lattice, species, coords) - assert len(tas2) == 13 - tas2.merge_sites(mode="d") - assert len(tas2) == 9 + struct_tas2 = Structure.from_spacegroup(160, lattice, species, coords) + assert len(struct_tas2) == 13 + struct_tas2.merge_sites(mode="delete") + assert len(struct_tas2) == 9 lattice = Lattice.hexagonal(3.587776, 19.622793) species = ["Na", "V", "S", "S"] @@ -1659,12 +1664,12 @@ def test_merge_sites(self): [0.333333, 0.666667, 0.399394], [0.666667, 0.333333, 0.597273], ] - navs2 = Structure.from_spacegroup(160, lattice, species, coords) - assert len(navs2) == 18 - navs2.merge_sites(mode="d") - assert len(navs2) == 12 + struct_navs2 = Structure.from_spacegroup(160, lattice, species, coords) + assert len(struct_navs2) == 18 + struct_navs2.merge_sites(mode="delete") + assert len(struct_navs2) == 12 - # Test that we can average the site properties that are floats + # Test that we can average the site properties that are numerical (float/int) lattice = Lattice.hexagonal(3.587776, 19.622793) species = ["Na", "V", "S", "S"] coords = [ @@ -1674,11 +1679,47 @@ def test_merge_sites(self): [0.666667, 0.333333, 0.597273], ] site_props = {"prop1": [3.0, 5.0, 7.0, 11.0]} - navs2 = Structure.from_spacegroup(160, lattice, species, coords, site_properties=site_props) - navs2.insert(0, "Na", coords[0], properties={"prop1": 100.0}) - navs2.merge_sites(mode="a") - assert len(navs2) == 12 - assert 51.5 in [itr.properties["prop1"] for itr in navs2] + struct_navs2 = Structure.from_spacegroup(160, lattice, species, coords, site_properties=site_props) + struct_navs2.insert(0, "Na", coords[0], properties={"prop1": 100}) # int property + struct_navs2.merge_sites(mode="average") + assert len(struct_navs2) == 12 + assert any(math.isclose(site.properties["prop1"], 51.5) for site in struct_navs2) + + # Test non-numerical property warning + struct_navs2.insert(0, "Na", coords[0], properties={"prop1": "hi"}) + with pytest.warns(UserWarning, match="But property is set to None"): + struct_navs2.merge_sites(mode="average") + + # Test property handling for np.array (selective dynamics) + poscar_str_0 = """Test POSCAR +1.0 +3.840198 0.000000 0.000000 +1.920099 3.325710 0.000000 +0.000000 -2.217138 3.135509 +1 1 +Selective dynamics +direct +0.000000 0.000000 0.000000 T T T Si +0.750000 0.500000 0.750000 F F F O +""" + poscar_str_1 = """offset a bit +1.0 +3.840198 0.000000 0.000000 +1.920099 3.325710 0.000000 +0.000000 -2.217138 3.135509 +1 1 +Selective dynamics +direct +0.100000 0.000000 0.000000 T T T Si +0.750000 0.500000 0.750000 F F F O +""" + + struct_0 = Poscar.from_str(poscar_str_0).structure + struct_1 = Poscar.from_str(poscar_str_1).structure + + for site in struct_0: + struct_1.append(site.species, site.frac_coords, properties=site.properties) + struct_1.merge_sites(mode="average") def test_properties(self): assert self.struct.num_sites == len(self.struct) diff --git a/tests/io/vasp/test_sets.py b/tests/io/vasp/test_sets.py index cef111d7d7e..243f365171c 100644 --- a/tests/io/vasp/test_sets.py +++ b/tests/io/vasp/test_sets.py @@ -2,7 +2,6 @@ import hashlib import os -import unittest from glob import glob from zipfile import ZipFile @@ -1610,7 +1609,6 @@ def test_user_incar_settings(self): assert not vis.incar["LASPH"], "LASPH user setting not applied" assert vis.incar["VDW_SR"] == 1.5, "VDW_SR user setting not applied" - @unittest.skipIf(not os.path.exists(TEST_DIR), "Test files are not present.") def test_from_prev_calc(self): prev_run = os.path.join(TEST_DIR, "fixtures", "relaxation") @@ -1627,7 +1625,6 @@ def test_from_prev_calc(self): assert "VDW_A2" in vis_bj.incar assert "VDW_S8" in vis_bj.incar - @unittest.skipIf(not os.path.exists(TEST_DIR), "Test files are not present.") def test_override_from_prev_calc(self): prev_run = os.path.join(TEST_DIR, "fixtures", "relaxation") diff --git a/tests/util/test_misc.py b/tests/util/test_misc.py index 4bcf881dcad..64fe7107bb0 100644 --- a/tests/util/test_misc.py +++ b/tests/util/test_misc.py @@ -46,12 +46,6 @@ def test_nested_arrays(self): def test_diff_dtype(self): """Make sure it also works for other data types as value.""" - - @dataclass - class CustomClass: - name: str - value: int - # Test with bool values dict1 = {"a": True} dict2 = {"a": True} @@ -69,6 +63,11 @@ class CustomClass: assert not is_np_dict_equal(dict4, dict6) # Test with a custom data class + @dataclass + class CustomClass: + name: str + value: int + dict7 = {"a": CustomClass(name="test", value=1)} dict8 = {"a": CustomClass(name="test", value=1)} assert is_np_dict_equal(dict7, dict8) @@ -76,6 +75,19 @@ class CustomClass: dict9 = {"a": CustomClass(name="test", value=2)} assert not is_np_dict_equal(dict7, dict9) + # Test __eq__ method being used + @dataclass + class NewCustomClass: + name: str + value: int + + def __eq__(self, other): + return True + + dict7_1 = {"a": NewCustomClass(name="test", value=1)} + dict8_1 = {"a": NewCustomClass(name="hello", value=2)} + assert is_np_dict_equal(dict7_1, dict8_1) + # Test with None dict10 = {"a": None} dict11 = {"a": None}