Skip to content

Commit

Permalink
ruff fix coordination
Browse files Browse the repository at this point in the history
  • Loading branch information
htz1992213 committed Feb 5, 2024
1 parent 362fcb0 commit c81b502
Showing 1 changed file with 24 additions and 37 deletions.
61 changes: 24 additions & 37 deletions mdgo/coordination.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
# Copyright (c) Tingzheng Hou.
# Distributed under the terms of the MIT License.

"""
This module implements functions for coordination analysis.
"""
"""This module implements functions for coordination analysis."""

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING

import numpy as np
from MDAnalysis import AtomGroup, Universe
from MDAnalysis.analysis.distances import distance_array
from MDAnalysis.core.groups import Atom
from scipy.signal import savgol_filter
from tqdm.auto import tqdm

from mdgo.util.coord import angle, atom_vec

if TYPE_CHECKING:
from collections.abc import Callable

from MDAnalysis import AtomGroup, Universe
from MDAnalysis.core.groups import Atom

__author__ = "Tingzheng Hou"
__version__ = "0.3.0"
__maintainer__ = "Tingzheng Hou"
Expand Down Expand Up @@ -51,26 +53,22 @@ def neighbor_distance(
A dictionary of distance of neighbor atoms to the ``center_atom``. The keys are atom indexes in string type .
"""
dist_dict = {}
time_count = 0
trj_analysis = nvt_run.trajectory[run_start:run_end:]
species_selection = select_dict.get(species)
if species_selection is None:
raise ValueError("Invalid species selection")
for ts in trj_analysis:
for _ts in trj_analysis:
selection = (
"(" + species_selection + ") and (around " + str(distance) + " index " + str(center_atom.index) + ")"
)
shell = nvt_run.select_atoms(selection, periodic=True)
for atom in shell.atoms:
if str(atom.index) not in dist_dict:
dist_dict[str(atom.index)] = np.full(run_end - run_start, 100.0)
time_count += 1
time_count = 0
for ts in trj_analysis:
for time_count, ts in enumerate(trj_analysis):
for atom_index, val in dist_dict.items():
dist = distance_array(ts[center_atom.index], ts[int(atom_index)], ts.dimensions)
val[time_count] = dist
time_count += 1
return dist_dict


Expand Down Expand Up @@ -370,7 +368,7 @@ def check_contiguous_steps(
}
trj_analysis = nvt_run.trajectory[run_start:run_end:]
has = False
for i, ts in enumerate(trj_analysis):
for i, _ts in enumerate(trj_analysis):
log = False
checkpoint = -1
for j in checkpoints:
Expand Down Expand Up @@ -680,7 +678,7 @@ def cluster_coordinates( # TODO: rewrite the method
cluster = []
for atom in shell:
coord_list = []
for ts in trj_analysis:
for _ts in trj_analysis:
coord_list.append(atom.position)
cluster.append(np.mean(np.array(coord_list), axis=0))
cluster_array = np.array(cluster)
Expand All @@ -701,8 +699,7 @@ def cluster_coordinates( # TODO: rewrite the method
vec3 = vec3 / np.linalg.norm(vec3)
basis_xyz = np.transpose([vec1, vec2, vec3])
cluster_norm = np.linalg.solve(basis_xyz, cluster_array.T).T
cluster_norm = cluster_norm - np.mean(cluster_norm, axis=0)
return cluster_norm
return cluster_norm - np.mean(cluster_norm, axis=0)
return cluster_array


Expand Down Expand Up @@ -739,14 +736,13 @@ def num_of_neighbor(
A diction containing the coordination number sequence of each specified neighbor species
and the total coordination number sequence in the specified frame range .
"""
time_count = 0
trj_analysis = nvt_run.trajectory[run_start:run_end:]
cn_values = {}
species = list(distance_dict.keys())
for kw in species:
cn_values[kw] = np.zeros(int(len(trj_analysis)))
cn_values["total"] = np.zeros(int(len(trj_analysis)))
for ts in trj_analysis:
for time_count, ts in enumerate(trj_analysis):
digit_of_species = len(species) - 1
for kw in species:
selection = select_shell(select_dict, distance_dict, center_atom, kw)
Expand All @@ -772,7 +768,6 @@ def num_of_neighbor(
center_name = center_atom.name
path = write_path + str(center_atom.id) + "_" + str(int(ts.time)) + "_" + str(structure_code) + ".xyz"
write_out(center_pos, center_name, structure, path)
time_count += 1
return cn_values


Expand Down Expand Up @@ -800,13 +795,12 @@ def num_of_neighbor_simple(
A dict with "total" as the key and an array of the solvation structure type in the specified frame range
as the value.
"""
time_count = 0
trj_analysis = nvt_run.trajectory[run_start:run_end:]
center_selection = "same type as index " + str(center_atom.index)
assert len(distance_dict) == 1, "Please only specify the counter-ion species in the distance_dict"
species = list(distance_dict.keys())[0]
species = next(iter(distance_dict.keys()))
cn_values = np.zeros(int(len(trj_analysis)))
for ts in trj_analysis:
for time_count, _ts in enumerate(trj_analysis):
selection = select_shell(select_dict, distance_dict, center_atom, species)
shell = nvt_run.select_atoms(selection, periodic=True)
shell_len = len(shell)
Expand All @@ -822,9 +816,7 @@ def num_of_neighbor_simple(
cn_values[time_count] = 3
else:
cn_values[time_count] = 3
time_count += 1
cn_values = {"total": cn_values}
return cn_values
return {"total": cn_values}


def angular_dist_of_neighbor(
Expand Down Expand Up @@ -860,7 +852,7 @@ def angular_dist_of_neighbor(
neighbor_a, neighbor_b, center_c = tuple(names)
acb_angle = []
trj_analysis = nvt_run.trajectory[run_start:run_end:]
for ts in trj_analysis:
for _ts in trj_analysis:
a_selection = select_shell(select_dict, distance_dict, center_atom, neighbor_a)
a_group = nvt_run.select_atoms(a_selection, periodic=True)
a_num = len(a_group)
Expand All @@ -870,10 +862,7 @@ def angular_dist_of_neighbor(
c_selection = select_shell(select_dict, distance_dict, a_group.atoms[0], center_c)
c_atoms = nvt_run.select_atoms(c_selection, periodic=True)
shell_species_len = len(c_atoms) - 1
if shell_species_len == 0:
shell_type = "cip"
else:
shell_type = "agg"
shell_type = "cip" if shell_species_len == 0 else "agg"
else:
shell_type = "agg"
if shell_type == "agg" and cip:
Expand Down Expand Up @@ -916,7 +905,6 @@ def num_of_neighbor_specific(
A tuple containing three dictionary of the coordination number of each neighbor species
and total coordination number for the three solvation structure type, respectively.
"""
time_count = 0
trj_analysis = nvt_run.trajectory[run_start:run_end:]
cip_step = []
ssip_step = []
Expand All @@ -925,7 +913,7 @@ def num_of_neighbor_specific(
for kw in distance_dict:
cn_values[kw] = np.zeros(int(len(trj_analysis)))
cn_values["total"] = np.zeros(int(len(trj_analysis)))
for ts in trj_analysis:
for time_count, _ts in enumerate(trj_analysis):
for kw in distance_dict:
kw_selection = select_shell(select_dict, distance_dict, center_atom, kw)
kw_shell = nvt_run.select_atoms(kw_selection, periodic=True)
Expand All @@ -948,7 +936,6 @@ def num_of_neighbor_specific(
agg_step.append(time_count)
else:
agg_step.append(time_count)
time_count += 1
cn_dict = {}
for kw in distance_dict:
cn_dict["ssip_" + kw] = cn_values[kw][ssip_step]
Expand Down Expand Up @@ -989,7 +976,8 @@ def full_solvation_structure( # TODO: rewrite the method
"""
center_selection = select_dict.get(center_species)
counter_selection = select_dict.get(counter_species)
assert (center_selection is not None) and (counter_selection is not None)
assert center_selection is not None
assert counter_selection is not None

def select_counter_ion(selection, dist, atom):
return "(" + selection + " and around " + str(dist) + " same fragment as index " + str(atom.index) + ")"
Expand Down Expand Up @@ -1021,7 +1009,7 @@ def counter_shell(this_shell, this_layer, frame):
time_count = 0
trj_analysis = nvt_run.trajectory[run_start:run_end:]
cn_values = np.zeros((int(len(trj_analysis)), depth))
for ts in trj_analysis:
for _ts in trj_analysis:
center_ion_list: list[np.int_] = [center_atom.id]
counter_ion_list: list[np.int_] = []
first_shell = nvt_run.select_atoms(
Expand Down Expand Up @@ -1132,5 +1120,4 @@ def select_shell(
distance_str = str(distance_value)
else:
distance_str = distance
selection = "(" + species_selection + ") and (around " + distance_str + " index " + str(center_atom.index) + ")"
return selection
return "(" + species_selection + ") and (around " + distance_str + " index " + str(center_atom.index) + ")"

0 comments on commit c81b502

Please sign in to comment.