-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
topology featurizers debug #460
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,7 @@ | ||
version = 1 | ||
|
||
test_patterns = ["tests/**"] | ||
|
||
exclude_patterns = [ | ||
"docs/", | ||
"dev/", | ||
"binder/" | ||
] | ||
|
||
[[analyzers]] | ||
name = "python" | ||
enabled = true | ||
|
||
[analyzers.meta] | ||
runtime_version = "3.x.x" | ||
runtime_version = "3.x.x" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ | |
|
||
from mofdscribe.featurizers.utils import flat | ||
from mofdscribe.featurizers.utils.aggregators import MA_ARRAY_AGGREGATORS | ||
from mofdscribe.featurizers.utils.substructures import filter_element | ||
from mofdscribe.featurizers.utils.substructures import filter_element_for_ph | ||
|
||
|
||
# @np_cache | ||
|
@@ -104,8 +104,8 @@ def make_supercell( | |
xyz_periodic_copies = [] | ||
element_copies = [] | ||
|
||
# xyz_periodic_copies.append(coords) | ||
# element_copies.append(np.array(elements).reshape(-1,1)) | ||
xyz_periodic_copies.append(coords) | ||
element_copies.append(np.array(elements).reshape(-1,1)) | ||
Comment on lines
+107
to
+108
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, seems like a good catch, because the |
||
min_range = -3 # we aren't going in the minimum direction too much, so can make this small | ||
max_range = 20 # make this large enough, but can modify if wanting an even larger cell | ||
|
||
|
@@ -228,7 +228,7 @@ def get_images( | |
# ToDo: only do this for all if we want | ||
def get_persistent_images_for_structure( | ||
structure: Structure, | ||
elements: List[List[str]], | ||
elements: List[str], | ||
compute_for_all_elements: bool = True, | ||
min_size: int = 20, | ||
spread: float = 0.2, | ||
|
@@ -245,7 +245,7 @@ def get_persistent_images_for_structure( | |
|
||
Args: | ||
structure (Structure): input structure | ||
elements (List[List[str]]): list of elements to compute for | ||
elements (List[str]): list of element groups to compute for | ||
compute_for_all_elements (bool): compute for all elements | ||
min_size (int): minimum size of the cell for construction of persistent images | ||
spread (float): spread of kernel for construction | ||
|
@@ -273,9 +273,9 @@ def get_persistent_images_for_structure( | |
specs = [] | ||
for mb, mp in zip(max_b, max_p): | ||
specs.append({"minBD": 0, "maxB": mb, "maxP": mp}) | ||
for element in elements: | ||
for elements_group in elements: | ||
try: | ||
filtered_structure = filter_element(structure, element) | ||
filtered_structure = filter_element_for_ph(structure, elements_group) | ||
coords, _weights, _elements = _coords_for_structure( | ||
filtered_structure, | ||
min_size=min_size, | ||
|
@@ -294,7 +294,7 @@ def get_persistent_images_for_structure( | |
dimensions=(0, 1, 2), | ||
) | ||
except Exception: | ||
logger.exception(f"Error computing persistent images for {element}") | ||
logger.exception(f"Error computing persistent images for {elements_group}") | ||
images = {} | ||
for dim in [0, 1, 2]: | ||
im = np.zeros((pixels[0], pixels[1])) | ||
|
@@ -304,8 +304,8 @@ def get_persistent_images_for_structure( | |
persistent_dia[:] = np.nan | ||
|
||
# ToDo: make sure that we have the correct length | ||
element_images["image"][element] = images | ||
element_images["array"][element] = persistent_dia | ||
element_images["image"][elements_group] = images | ||
element_images["array"][elements_group] = persistent_dia | ||
|
||
if compute_for_all_elements: | ||
try: | ||
|
@@ -391,7 +391,7 @@ def get_diagrams_for_structure( | |
nan_array[:] = np.nan | ||
for element in elements: | ||
try: | ||
filtered_structure = filter_element(structure, element) | ||
filtered_structure = filter_element_for_ph(structure, element) | ||
coords, weights, _elements = _coords_for_structure( | ||
filtered_structure, | ||
min_size=min_size, | ||
|
@@ -442,7 +442,7 @@ def get_persistence_image_limits_for_structure( | |
limits = defaultdict(list) | ||
for element in elements: | ||
try: | ||
filtered_structure = filter_element(structure, element) | ||
filtered_structure = filter_element_for_ph(structure, element) | ||
|
||
coords, weights, _elements = _coords_for_structure( | ||
filtered_structure, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,14 +34,14 @@ class PHHist(MOFBaseFeaturizer): | |
|
||
def __init__( | ||
self, | ||
atom_types: Tuple[str] = ( | ||
atom_types: List[str] = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tuple was originally used because it is immutable and hence a better default for arguments. But since we do not change the |
||
"C-H-N-O", | ||
"F-Cl-Br-I", | ||
"Cu-Mn-Ni-Mo-Fe-Pt-Zn-Ca-Er-Au-Cd-Co-Gd-Na-Sm-Eu-Tb-V" | ||
"-Ag-Nd-U-Ba-Ce-K-Ga-Cr-Al-Li-Sc-Ru-In-Mg-Zr-Dy-W-Yb-Y-" | ||
"Ho-Re-Be-Rb-La-Sn-Cs-Pb-Pr-Bi-Tm-Sr-Ti-Hf-Ir-Nb-Pd-Hg-" | ||
"Th-Np-Lu-Rh-Pu", | ||
), | ||
], | ||
compute_for_all_elements: bool = True, | ||
dimensions: Tuple[int] = (1, 2), | ||
min_size: int = 20, | ||
|
@@ -57,12 +57,12 @@ def __init__( | |
"""Initialize the PHStats object. | ||
|
||
Args: | ||
atom_types (tuple): Atoms that are used to create substructures | ||
atom_types (list): Atoms that are used to create substructures | ||
for which the persistent homology statistics are computed. | ||
Defaults to ( "C-H-N-O", "F-Cl-Br-I", | ||
Defaults to [ "C-H-N-O", "F-Cl-Br-I", | ||
"Cu-Mn-Ni-Mo-Fe-Pt-Zn-Ca-Er-Au-Cd-Co-Gd-Na-Sm-Eu-Tb-V-Ag-Nd-U-Ba-Ce-K-Ga- | ||
Cr-Al-Li-Sc-Ru-In-Mg-Zr-Dy-W-Yb-Y-Ho-Re-Be-Rb-La-Sn-Cs-Pb-Pr-Bi-Tm-Sr-Ti- | ||
Hf-Ir-Nb-Pd-Hg-Th-Np-Lu-Rh-Pu", ). | ||
Hf-Ir-Nb-Pd-Hg-Th-Np-Lu-Rh-Pu", ]. | ||
compute_for_all_elements (bool): Compute descriptor for original structure with all atoms. | ||
Defaults to True. | ||
dimensions (Tuple[int]): Dimensions of topological features to consider. | ||
|
@@ -89,9 +89,10 @@ def __init__( | |
""" | ||
atom_types = [] if atom_types is None else atom_types | ||
self.elements = atom_types | ||
self.atom_types = ( | ||
list(atom_types) + ["all"] if compute_for_all_elements else list(atom_types) | ||
) | ||
if compute_for_all_elements: | ||
self.atom_types = atom_types + ["all"] | ||
else: | ||
self.atom_types = atom_types | ||
self.compute_for_all_elements = compute_for_all_elements | ||
self.dimensions = dimensions | ||
self.min_size = min_size | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,39 @@ def filter_element( | |
return Structure.from_sites(keep_sites) | ||
else: # input is molecule or IMolecule | ||
return Molecule.from_sites(keep_sites) | ||
|
||
|
||
def filter_element_for_ph( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (complexity): Consider refactoring to consolidate logic into a single function. The new code introduces unnecessary complexity by duplicating logic in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you make a suggestion? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Certainly! To address the complexity issue, you could refactor the
Here's a rough example of how you might structure the refactoring:

```python
def filter_elements_common(sites, elements, elements_group):
    return [site for site in sites if site.specie.symbol in elements or site.specie.symbol in elements_group]


def filter_element(structure, elements, elements_group=()):
    keep_sites = filter_elements_common(structure.sites, elements, elements_group)
    return Structure.from_sites(keep_sites) if isinstance(structure, Structure) else Molecule.from_sites(keep_sites)


def filter_element_for_ph(structure, elements, elements_group=()):
    # If there's any specific logic for persistence homology, handle it here
    keep_sites = filter_elements_common(structure.sites, elements, elements_group)
    return Structure.from_sites(keep_sites) if isinstance(structure, Structure) else Molecule.from_sites(keep_sites)
```

This approach keeps the core filtering logic in one place, making it easier to update and maintain. If |
||
structure: Union[Structure, IStructure, Molecule, IMolecule], elements: str | ||
) -> Structure: | ||
"""Filter a structure by element. | ||
|
||
Args: | ||
structure (Union[Structure, IStructure, Molecule, IMolecule]): input structure | ||
elements (str): element to filter | ||
|
||
Returns: | ||
filtered_structure (Structure): filtered structure | ||
""" | ||
elements_ = [] | ||
elements_group = (elements,) | ||
for atom_type in elements_group: | ||
Comment on lines
+52
to
+54
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure I understand this change. You are certainly right that the old code behaves in an unexpected way for elements with more than two symbols in the letter. But wouldn't the clearer fix then be something else? For example, always split by |
||
if "-" in atom_type: | ||
elements_.extend(atom_type.split("-")) | ||
else: | ||
elements_.append(atom_type) | ||
keep_sites = [] | ||
for site in structure.sites: | ||
if site.specie.symbol in elements_: | ||
keep_sites.append(site) | ||
if len(keep_sites) == 0: | ||
return None | ||
|
||
input_is_structure = isinstance(structure, (Structure, IStructure)) | ||
if input_is_structure: | ||
return Structure.from_sites(keep_sites) | ||
else: # input is molecule or IMolecule | ||
return Molecule.from_sites(keep_sites) | ||
|
||
|
||
def elements_in_structure(structure: Structure) -> List[str]: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that deepsource is not heavily used at the moment, but why are those changes needed for this PR?