Skip to content

Commit

Permalink
Fix array comparison in core.Structure.merge_sites, also allow `int…
Browse files Browse the repository at this point in the history
…` property to be merged instead of `float` alone, `mode` only allow full name (#4198)

* oops wrong branch

* test __eq__

* fix typo

* add test for illegal mode

* minor code clean up

* add test

* add test for non-numerical warning

* enhance dostring

* more descriptive var name

* use pytest skipif mark

* mark classvar

* remove skip mark altogether, no reason to check test file

* add type alias type

* NEED CONFIRM: make properties instance var

* ruff fix
  • Loading branch information
DanielYang59 authored Jan 9, 2025
1 parent 9918e89 commit 361106f
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 71 deletions.
24 changes: 12 additions & 12 deletions src/pymatgen/core/sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def __init__(
iii.Dict of elements/species and occupancies, e.g.
{"Fe" : 0.5, "Mn":0.5}. This allows the setup of
disordered structures.
coords: Cartesian coordinates of site.
properties: Properties associated with the site as a dict, e.g.
coords (ArrayLike): Cartesian coordinates of site.
properties (dict): Properties associated with the site, e.g.
{"magmom": 5}. Defaults to None.
label: Label for the site. Defaults to None.
skip_checks: Whether to ignore all the usual checks and just
label (str): Label for the site. Defaults to None.
skip_checks (bool): Whether to ignore all the usual checks and just
create the site. Use this if the Site is created in a controlled
manner and speed is desired.
"""
Expand Down Expand Up @@ -310,20 +310,20 @@ def __init__(
symbols, e.g. "Li", "Fe2+", "P" or atomic numbers,
e.g. 3, 56, or actual Element or Species objects.
iii.Dict of elements/species and occupancies, e.g.
{"Fe" : 0.5, "Mn":0.5}. This allows the setup of
{"Fe": 0.5, "Mn": 0.5}. This allows the setup of
disordered structures.
coords: Coordinates of site, fractional coordinates
coords (ArrayLike): Coordinates of site, fractional coordinates
by default. See ``coords_are_cartesian`` for more details.
lattice: Lattice associated with the site.
to_unit_cell: Translates fractional coordinate to the
lattice (Lattice): Lattice associated with the site.
to_unit_cell (bool): Translates fractional coordinate to the
basic unit cell, i.e. all fractional coordinates satisfy 0
<= a < 1. Defaults to False.
coords_are_cartesian: Set to True if you are providing
coords_are_cartesian (bool): Set to True if you are providing
Cartesian coordinates. Defaults to False.
properties: Properties associated with the site as a dict, e.g.
properties (dict): Properties associated with the site, e.g.
{"magmom": 5}. Defaults to None.
label: Label for the site. Defaults to None.
skip_checks: Whether to ignore all the usual checks and just
label (str): Label for the site. Defaults to None.
skip_checks (bool): Whether to ignore all the usual checks and just
create the site. Use this if the PeriodicSite is created in a
controlled manner and speed is desired.
"""
Expand Down
80 changes: 51 additions & 29 deletions src/pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator, Sequence
from typing import Any, SupportsIndex
from typing import Any, ClassVar, SupportsIndex, TypeAlias

import pandas as pd
from ase import Atoms
Expand All @@ -59,7 +59,7 @@

from pymatgen.util.typing import CompositionLike, MillerIndex, PathLike, PbcLike, SpeciesLike

FileFormats = Literal[
FileFormats: TypeAlias = Literal[
"cif",
"poscar",
"cssr",
Expand All @@ -73,7 +73,7 @@
"aims",
"",
]
StructureSources = Literal["Materials Project", "COD"]
StructureSources: TypeAlias = Literal["Materials Project", "COD"]


class Neighbor(Site):
Expand Down Expand Up @@ -214,8 +214,11 @@ class SiteCollection(collections.abc.Sequence, ABC):
"""

# Tolerance in Angstrom for determining if sites are too close
DISTANCE_TOLERANCE = 0.5
_properties: dict
DISTANCE_TOLERANCE: ClassVar[float] = 0.5

def __init__(self) -> None:
"""Init a SiteCollection."""
self._properties: dict

def __contains__(self, site: object) -> bool:
return site in self.sites
Expand Down Expand Up @@ -4717,44 +4720,63 @@ def scale_lattice(self, volume: float) -> Self:

return self

def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average"] = "sum") -> Self:
"""Merges sites (adding occupancies) within tol of each other.
Removes site properties.
def merge_sites(
self,
tol: float = 0.01,
mode: Literal["sum", "delete", "average"] = "sum",
) -> Self:
"""Merges sites (by adding occupancies) within tolerance and removes
site properties in "sum/delete" modes.
Args:
tol (float): Tolerance for distance to merge sites.
mode ("sum" | "delete" | "average"): "delete" means duplicate sites are
deleted. "sum" means the occupancies are summed for the sites.
"average" means that the site is deleted but the properties are averaged
Only first letter is considered.
mode ("sum" | "delete" | "average"): Only first letter is considered at this moment.
- "delete": delete duplicate sites.
- "sum": sum the occupancies for the sites.
- "average": delete the site but average the properties if it's numerical.
Returns:
Structure: self with merged sites.
Structure: Structure with merged sites.
"""
dist_mat = self.distance_matrix
# TODO: change the code the allow full name after 2025-12-01
# TODO2: add a test for mode value, currently it only checks if first letter is "s/a"
if mode.lower() not in {"sum", "delete", "average"} and mode.lower()[0] in {"s", "d", "a"}:
warnings.warn(
"mode would only allow full name sum/delete/average after 2025-12-01", DeprecationWarning, stacklevel=2
)

if mode.lower()[0] not in {"s", "d", "a"}:
raise ValueError(f"Illegal {mode=}, should start with a/d/s.")

dist_mat: NDArray = self.distance_matrix
np.fill_diagonal(dist_mat, 0)
clusters = fcluster(linkage(squareform((dist_mat + dist_mat.T) / 2)), tol, "distance")
sites = []

sites: list[PeriodicSite] = []
for cluster in np.unique(clusters):
inds = np.where(clusters == cluster)[0]
species = self[inds[0]].species
coords = self[inds[0]].frac_coords
props = self[inds[0]].properties
for n, i in enumerate(inds[1:]):
sp = self[i].species
indexes = np.where(clusters == cluster)[0]
species: Composition = self[indexes[0]].species
coords: NDArray = self[indexes[0]].frac_coords
props: dict = self[indexes[0]].properties

for site_idx, clust_idx in enumerate(indexes[1:]):
# Sum occupancies in "sum" mode
if mode.lower()[0] == "s":
species += sp
offset = self[i].frac_coords - coords
coords += ((offset - np.round(offset)) / (n + 2)).astype(coords.dtype)
for key in props:
if props[key] is not None and self[i].properties[key] != props[key]:
if mode.lower()[0] == "a" and isinstance(props[key], float):
species += self[clust_idx].species

offset = self[clust_idx].frac_coords - coords
coords += ((offset - np.round(offset)) / (site_idx + 2)).astype(coords.dtype)
for key, val in props.items():
if val is not None and not np.array_equal(self[clust_idx].properties[key], val):
if mode.lower()[0] == "a" and isinstance(val, float | int):
# update a running total
props[key] = props[key] * (n + 1) / (n + 2) + self[i].properties[key] / (n + 2)
props[key] = val * (site_idx + 1) / (site_idx + 2) + self[clust_idx].properties[key] / (
site_idx + 2
)
else:
props[key] = None
warnings.warn(
f"Sites with different site property {key} are merged. So property is set to none",
f"Sites with different site property {key} are merged. But property is set to None",
stacklevel=2,
)
sites.append(PeriodicSite(species, coords, self.lattice, properties=props))
Expand Down
83 changes: 62 additions & 21 deletions tests/core/test_structure.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import json
import math
import os
from fractions import Fraction
from pathlib import Path
from shutil import which
from unittest import skipIf

import numpy as np
import pytest
Expand All @@ -29,6 +29,7 @@
from pymatgen.electronic_structure.core import Magmom
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.io.cif import CifParser
from pymatgen.io.vasp.inputs import Poscar
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, PymatgenTest

Expand All @@ -40,11 +41,11 @@
ase = Atoms = Calculator = EMT = None


enum_cmd = which("enum.x") or which("multienum.x")
mcsqs_cmd = which("mcsqs")
ENUM_CMD = which("enum.x") or which("multienum.x")
MCSQS_CMD = which("mcsqs")


class TestNeighbor(PymatgenTest):
class TestNeighbor:
def test_msonable(self):
struct = PymatgenTest.get_structure("Li2O")
nn = struct.get_neighbors(struct[0], r=3)
Expand Down Expand Up @@ -102,7 +103,7 @@ def setUp(self):
)
self.V2O3 = IStructure.from_file(f"{TEST_FILES_DIR}/cif/V2O3.cif")

@skipIf(not (mcsqs_cmd and enum_cmd), reason="enumlib or mcsqs executable not present")
@pytest.mark.skipif(not (MCSQS_CMD and ENUM_CMD), reason="enumlib or mcsqs executable not present")
def test_get_orderings(self):
ordered = Structure.from_spacegroup("Im-3m", Lattice.cubic(3), ["Fe"], [[0, 0, 0]])
assert ordered.get_orderings()[0] == ordered
Expand Down Expand Up @@ -1633,23 +1634,27 @@ def test_merge_sites(self):
[0.5, 0.5, 1.501],
]
struct = Structure(Lattice.cubic(1), species, coords)
struct.merge_sites(mode="s")
struct.merge_sites(mode="sum")
assert struct[0].specie.symbol == "Ag"
assert struct[1].species == Composition({"Cl": 0.35, "F": 0.25})
assert_allclose(struct[1].frac_coords, [0.5, 0.5, 0.5005])

# Test for TaS2 with spacegroup 166 in 160 setting.
# Test illegal mode
with pytest.raises(ValueError, match="Illegal mode='illegal', should start with a/d/s"):
struct.merge_sites(mode="illegal")

# Test for TaS2 with spacegroup 166 in 160 setting
lattice = Lattice.hexagonal(3.374351, 20.308941)
species = ["Ta", "S", "S"]
coords = [
[0, 0, 0.944333],
[0.333333, 0.666667, 0.353424],
[0.666667, 0.333333, 0.535243],
]
tas2 = Structure.from_spacegroup(160, lattice, species, coords)
assert len(tas2) == 13
tas2.merge_sites(mode="d")
assert len(tas2) == 9
struct_tas2 = Structure.from_spacegroup(160, lattice, species, coords)
assert len(struct_tas2) == 13
struct_tas2.merge_sites(mode="delete")
assert len(struct_tas2) == 9

lattice = Lattice.hexagonal(3.587776, 19.622793)
species = ["Na", "V", "S", "S"]
Expand All @@ -1659,12 +1664,12 @@ def test_merge_sites(self):
[0.333333, 0.666667, 0.399394],
[0.666667, 0.333333, 0.597273],
]
navs2 = Structure.from_spacegroup(160, lattice, species, coords)
assert len(navs2) == 18
navs2.merge_sites(mode="d")
assert len(navs2) == 12
struct_navs2 = Structure.from_spacegroup(160, lattice, species, coords)
assert len(struct_navs2) == 18
struct_navs2.merge_sites(mode="delete")
assert len(struct_navs2) == 12

# Test that we can average the site properties that are floats
# Test that we can average the site properties that are numerical (float/int)
lattice = Lattice.hexagonal(3.587776, 19.622793)
species = ["Na", "V", "S", "S"]
coords = [
Expand All @@ -1674,11 +1679,47 @@ def test_merge_sites(self):
[0.666667, 0.333333, 0.597273],
]
site_props = {"prop1": [3.0, 5.0, 7.0, 11.0]}
navs2 = Structure.from_spacegroup(160, lattice, species, coords, site_properties=site_props)
navs2.insert(0, "Na", coords[0], properties={"prop1": 100.0})
navs2.merge_sites(mode="a")
assert len(navs2) == 12
assert 51.5 in [itr.properties["prop1"] for itr in navs2]
struct_navs2 = Structure.from_spacegroup(160, lattice, species, coords, site_properties=site_props)
struct_navs2.insert(0, "Na", coords[0], properties={"prop1": 100}) # int property
struct_navs2.merge_sites(mode="average")
assert len(struct_navs2) == 12
assert any(math.isclose(site.properties["prop1"], 51.5) for site in struct_navs2)

# Test non-numerical property warning
struct_navs2.insert(0, "Na", coords[0], properties={"prop1": "hi"})
with pytest.warns(UserWarning, match="But property is set to None"):
struct_navs2.merge_sites(mode="average")

# Test property handling for np.array (selective dynamics)
poscar_str_0 = """Test POSCAR
1.0
3.840198 0.000000 0.000000
1.920099 3.325710 0.000000
0.000000 -2.217138 3.135509
1 1
Selective dynamics
direct
0.000000 0.000000 0.000000 T T T Si
0.750000 0.500000 0.750000 F F F O
"""
poscar_str_1 = """offset a bit
1.0
3.840198 0.000000 0.000000
1.920099 3.325710 0.000000
0.000000 -2.217138 3.135509
1 1
Selective dynamics
direct
0.100000 0.000000 0.000000 T T T Si
0.750000 0.500000 0.750000 F F F O
"""

struct_0 = Poscar.from_str(poscar_str_0).structure
struct_1 = Poscar.from_str(poscar_str_1).structure

for site in struct_0:
struct_1.append(site.species, site.frac_coords, properties=site.properties)
struct_1.merge_sites(mode="average")

def test_properties(self):
assert self.struct.num_sites == len(self.struct)
Expand Down
3 changes: 0 additions & 3 deletions tests/io/vasp/test_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import hashlib
import os
import unittest
from glob import glob
from zipfile import ZipFile

Expand Down Expand Up @@ -1610,7 +1609,6 @@ def test_user_incar_settings(self):
assert not vis.incar["LASPH"], "LASPH user setting not applied"
assert vis.incar["VDW_SR"] == 1.5, "VDW_SR user setting not applied"

@unittest.skipIf(not os.path.exists(TEST_DIR), "Test files are not present.")
def test_from_prev_calc(self):
prev_run = os.path.join(TEST_DIR, "fixtures", "relaxation")

Expand All @@ -1627,7 +1625,6 @@ def test_from_prev_calc(self):
assert "VDW_A2" in vis_bj.incar
assert "VDW_S8" in vis_bj.incar

@unittest.skipIf(not os.path.exists(TEST_DIR), "Test files are not present.")
def test_override_from_prev_calc(self):
prev_run = os.path.join(TEST_DIR, "fixtures", "relaxation")

Expand Down
24 changes: 18 additions & 6 deletions tests/util/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,6 @@ def test_nested_arrays(self):

def test_diff_dtype(self):
"""Make sure it also works for other data types as value."""

@dataclass
class CustomClass:
name: str
value: int

# Test with bool values
dict1 = {"a": True}
dict2 = {"a": True}
Expand All @@ -69,13 +63,31 @@ class CustomClass:
assert not is_np_dict_equal(dict4, dict6)

# Test with a custom data class
@dataclass
class CustomClass:
name: str
value: int

dict7 = {"a": CustomClass(name="test", value=1)}
dict8 = {"a": CustomClass(name="test", value=1)}
assert is_np_dict_equal(dict7, dict8)

dict9 = {"a": CustomClass(name="test", value=2)}
assert not is_np_dict_equal(dict7, dict9)

# Test __eq__ method being used
@dataclass
class NewCustomClass:
name: str
value: int

def __eq__(self, other):
return True

dict7_1 = {"a": NewCustomClass(name="test", value=1)}
dict8_1 = {"a": NewCustomClass(name="hello", value=2)}
assert is_np_dict_equal(dict7_1, dict8_1)

# Test with None
dict10 = {"a": None}
dict11 = {"a": None}
Expand Down

0 comments on commit 361106f

Please sign in to comment.