Skip to content

Commit

Permalink
Merge branch 'main' of github.com:troyvvgroup/quemb into improve_frag…
Browse files Browse the repository at this point in the history
…mentation
  • Loading branch information
mcocdawc committed Jan 17, 2025
2 parents 2df5ea6 + 8c410c4 commit 53d7faf
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 34 deletions.
1 change: 1 addition & 0 deletions .ruff.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[lint]
# see the rules [here](https://docs.astral.sh/ruff/rules/)
select = ["E", "F", "I", "NPY", "PL", "ARG"]
exclude = ["tests/fragmentation_test.py"]
ignore = [
"S101",
# https://docs.astral.sh/ruff/rules/assert/
Expand Down
49 changes: 28 additions & 21 deletions src/quemb/molbe/autofrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from quemb.shared.helper import unused


@define(init=True)
@define
class FragmentMap:
"""Dataclass for fragment bookkeeping.
Expand Down Expand Up @@ -61,24 +61,37 @@ class FragmentMap:
adjacency_mat: np.ndarray
adjacency_graph: nx.Graph

def remove_nonnunique_frags(self) -> None:
for adx, basa in enumerate(self.fsites):
for bdx, basb in enumerate(self.fsites):
if adx == bdx:
pass
elif set(basb).issubset(set(basa)):
tmp = set(self.center[adx] + self.center[bdx])
self.center[adx] = tuple(tmp)
del self.center[bdx]
del self.fsites[bdx]
del self.fs[bdx]
def remove_nonnunique_frags(self, natm: int) -> None:
"""Remove all fragments which are strict subsets of another.
Remove all fragments whose AO indices can be identified as subsets of
another fragment's. The center site for the removed frag is then
added to that of the superset. Because doing so will necessarily
change the definition of fragments, we repeat it up to `natm` times
such that all fragments are guaranteed to be distinct sets.
another fragment's. The center site for the removed frag is then
added to that of the superset. Because doing so will necessarily
change the definition of fragments, we repeat it up to `natm` times
such that all fragments are guaranteed to be distinct sets.
"""
for _ in range(0, natm):
for adx, basa in enumerate(self.fsites):
for bdx, basb in enumerate(self.fsites):
if adx == bdx:
pass
elif set(basb).issubset(set(basa)):
tmp = set(self.center[adx] + self.center[bdx])
self.center[adx] = tuple(tmp)
del self.center[bdx]
del self.fsites[bdx]
del self.fs[bdx]

return None


def euclidean_norm(
i_coord: float,
j_coord: float,
i_coord: np.ndarray,
j_coord: np.ndarray,
) -> np.floating[Any]:
return norm(np.asarray(i_coord - j_coord))

Expand Down Expand Up @@ -241,14 +254,8 @@ def graphgen(
else:
raise AttributeError(f"Connectivity metric not recognized: '{connectivity}'")

# Remove all fragments whose AO indices can be identified as subsets of
# another fragment's. The center site for the removed frag is then
# added to that of the superset. Because doing so will necessarily
# change the definition of fragments, we repeat it up to `natm` times
# such that all fragments are guaranteed to be distinct sets.
if remove_nonunique_frags:
for _ in range(0, natm):
fragment_map.remove_nonnunique_frags()
fragment_map.remove_nonnunique_frags(natm)

# Define the 'edges' for fragment A as the intersect of its sites
# with the set of all center sites outside of A:
Expand Down
17 changes: 4 additions & 13 deletions src/quemb/molbe/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,18 @@ def __init__(
if frozen_core:
self.ncore, self.no_core_idx, self.core_list = get_core(self.mol)

if frag_type != "hchain_simple" and self.mol is None:
raise ValueError("Provide pyscf gto.M object in fragpart() and restart!")

# Check type of fragmentation function
if frag_type == "hchain_simple":
# This is an experimental feature.
self.hchain_simple()

elif frag_type == "chain":
if mol is None:
raise ValueError(
"Provide pyscf gto.M object in fragpart() and restart!"
)
self.chain(mol, frozen_core=frozen_core, closed=closed)
self.chain(self.mol, frozen_core=frozen_core, closed=closed)

elif frag_type == "graphgen":
if self.mol is None:
raise ValueError(
"Provide pyscf gto.M object in fragpart() and restart!"
)
fragment_map = graphgen(
mol=self.mol.copy(),
be_type=be_type,
Expand All @@ -120,10 +115,6 @@ def __init__(
self.Nfrag = len(self.fsites)

elif frag_type == "autogen":
if mol is None:
raise ValueError(
"Provide pyscf gto.M object in fragpart() and restart!"
)
fgs = autogen(
mol,
be_type=be_type,
Expand Down

0 comments on commit 53d7faf

Please sign in to comment.