Skip to content

Commit

Permalink
Merge pull request #596 from JiQi535/master
Browse files Browse the repository at this point in the history
M3GNetSite describer; DIRECT samples smallest cells possible
  • Loading branch information
shyuep authored Mar 20, 2024
2 parents eb4a8ea + 1463eee commit 2c1bf6a
Show file tree
Hide file tree
Showing 34 changed files with 579 additions and 208 deletions.
9 changes: 5 additions & 4 deletions maml/apps/bowsr/model/megnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def __init__(self, model=None, reconstruct=False, **kwargs):
"""
Args:
model: MEGNet energy model
model: MEGNet energy model.
reconstruct: Whether to reconstruct the model (used in
disordered model)
**kwargs:
disordered model).
**kwargs: "gaussian_cutoff", "radius_cutoff", and "npass".
"""
model = model or MEGNetModel.from_file(model_filename)
gaussian_cutoff = kwargs.get("gaussian_cutoff", 6)
Expand All @@ -53,7 +53,8 @@ def __init__(self, model=None, reconstruct=False, **kwargs):
self.embedding = weights[0]
if reconstruct:
cg = CrystalGraph(
bond_converter=GaussianDistance(np.linspace(0, gaussian_cutoff, 100), 0.5), cutoff=radius_cutoff
bond_converter=GaussianDistance(np.linspace(0, gaussian_cutoff, 100), 0.5),
cutoff=radius_cutoff,
)
model_new = MEGNetModel(100, 2, 16, npass=npass, graph_converter=cg)
model_new.set_weights(weights[1:])
Expand Down
12 changes: 8 additions & 4 deletions maml/apps/bowsr/perturbation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def perturbation_mapping(x, fixed_indices):
Perturbation mapping.
Args:
x:
fixed_indices:
x: Perturbation mapping for unfixed lattice.
fixed_indices: indices to fix.
Returns:
Returns: Perturbation mapping.
"""
return np.array(
Expand All @@ -58,7 +58,11 @@ class WyckoffPerturbation:
"""

def __init__(
self, int_symbol: int, wyckoff_symbol: str, symmetry_ops: list[SymmOp] | None = None, use_symmetry: bool = True
self,
int_symbol: int,
wyckoff_symbol: str,
symmetry_ops: list[SymmOp] | None = None,
use_symmetry: bool = True,
):
"""
Args:
Expand Down
19 changes: 14 additions & 5 deletions maml/apps/bowsr/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ class StandardScaler(MSONable):
dictionary representation.
"""

def __init__(self, mean: list | np.ndarray | None = None, std: list | np.ndarray | None = None):
def __init__(
self,
mean: list | np.ndarray | None = None,
std: list | np.ndarray | None = None,
):
"""
Args:
mean: np.ndarray, mean values
Expand Down Expand Up @@ -119,10 +123,15 @@ def __repr__(self):

def as_dict(self):
"""
Serialize the instance into dictionary
Returns:
Serialize the instance into dictionary.
Returns: dict.
"""
return {"@module": self.__class__.__module__, "@class": self.__class__.__name__, "params": {}}
return {
"@module": self.__class__.__module__,
"@class": self.__class__.__name__,
"params": {},
}

@classmethod
def from_dict(cls, d):
Expand All @@ -131,7 +140,7 @@ def from_dict(cls, d):
Args:
d: Dict, dictionary contain class initialization parameters.
Returns:
Returns: DummyScaler.
"""
return cls()
44 changes: 37 additions & 7 deletions maml/apps/gbe/describer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ class GBDescriber(BaseDescriber):
with selected structural and elemental features.
"""

def __init__(self, structural_features: list | None = None, elemental_features: list | None = None, **kwargs):
def __init__(
self,
structural_features: list | None = None,
elemental_features: list | None = None,
**kwargs,
):
"""
Args:
Expand All @@ -99,15 +104,31 @@ def __init__(self, structural_features: list | None = None, elemental_features:
**kwargs (dict): parameters for BaseDescriber.
"""
if not elemental_features:
elemental_features = [preset.e_coh, preset.G, preset.a0, preset.ar, preset.mean_delta_bl, preset.mean_bl]
elemental_features = [
preset.e_coh,
preset.G,
preset.a0,
preset.ar,
preset.mean_delta_bl,
preset.mean_bl,
]
if not structural_features:
structural_features = [preset.d_gb, preset.d_rot, preset.sin_theta, preset.cos_theta]
structural_features = [
preset.d_gb,
preset.d_rot,
preset.sin_theta,
preset.cos_theta,
]
self.elem_features = elemental_features
self.struc_features = structural_features
super().__init__(**kwargs)

def transform_one(
self, db_entry: dict, inc_target: bool = True, inc_bulk_ref: bool = True, mp_api: str | None = None
self,
db_entry: dict,
inc_target: bool = True,
inc_bulk_ref: bool = True,
mp_api: str | None = None,
) -> pd.DataFrame:
"""
Describe gb with selected structural and elemental features
Expand Down Expand Up @@ -202,7 +223,10 @@ def get_structural_feature(db_entry: dict, features: list | None = None) -> pd.D


def get_elemental_feature(
db_entry: dict, loc_algo: str = "crystalnn", features: list | None = None, mp_api: str | None = None
db_entry: dict,
loc_algo: str = "crystalnn",
features: list | None = None,
mp_api: str | None = None,
) -> pd.DataFrame:
"""
Function to get the elemental features
Expand Down Expand Up @@ -303,7 +327,13 @@ class GBBond(MSONable):
]
}

def __init__(self, gb: GrainBoundary, loc_algo: str = "crystalnn", bond_mat: np.ndarray | None = None, **kwargs):
def __init__(
self,
gb: GrainBoundary,
loc_algo: str = "crystalnn",
bond_mat: np.ndarray | None = None,
**kwargs,
):
"""
Args:
Expand All @@ -312,7 +342,7 @@ def __init__(self, gb: GrainBoundary, loc_algo: str = "crystalnn", bond_mat: np.
See options: GBBond.NNDict.keys()
Default: crystalnn
bond_mat (np.ndarray): optional.
**kwargs: pass to loc_algo.
"""
self.loc_algo = self.NNDict[loc_algo](**kwargs)
self.gb = gb
Expand Down
74 changes: 56 additions & 18 deletions maml/apps/pes/_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ def feed(attribute, kwargs, dictionary, tab="\t"):
(str).
"""
tmp = kwargs.get(attribute) if kwargs.get(attribute) else dictionary.get(attribute).get("value")
return tab + dictionary.get(attribute).get("name"), str(tmp), dictionary.get(attribute).get("comment")
return (
tab + dictionary.get(attribute).get("name"),
str(tmp),
dictionary.get(attribute).get("comment"),
)


class MTPotential(LammpsPotential):
Expand Down Expand Up @@ -100,10 +104,26 @@ def _line_up(self, structure, energy, forces, virial_stress):
format_str = "{:>14s}{:>5s}{:>15s}{:>14s}{:>14s}{:>13s}{:>13s}{:>13s}"
format_float = "{:>14d}{:>5d}{:>15f}{:>14f}{:>14f}{:>13f}{:>13f}{:>13f}"
lines.append(
format_str.format("AtomData: id", "type", "cartes_x", "cartes_y", "cartes_z", "fx", "fy", "fz")
format_str.format(
"AtomData: id",
"type",
"cartes_x",
"cartes_y",
"cartes_z",
"fx",
"fy",
"fz",
)
)
for i, (site, force) in enumerate(zip(structure, forces)):
lines.append(format_float.format(i + 1, self.elements.index(str(site.specie)), *site.coords, *force))
lines.append(
format_float.format(
i + 1,
self.elements.index(str(site.specie)),
*site.coords,
*force,
)
)
if "Energy" in inputs:
lines.append(" Energy")
lines.append(f"{inputs['Energy']:>24.12f}")
Expand All @@ -128,7 +148,7 @@ def write_cfg(self, filename, cfg_pool):
filename (str): filename
cfg_pool (list): list of configurations.
Returns:
Returns: filename.
"""
if not self.elements:
Expand All @@ -151,16 +171,18 @@ def write_cfg(self, filename, cfg_pool):

return filename

def write_ini(self, mtp_filename="fitted.mtp", select=False, **kwargs):
def write_ini(self, mtp_filename="fitted.mtp", **kwargs):
"""
Write mlip.ini file for mlip packages of version mlip-2 or mlip-dev.
Supported keyword arguments are parallel with options stated in the mlip manuals.
mlip-2 is recommended, as it is the only officially supported version by mlip.
Please refer to https://mlip.skoltech.ru.
Args:
mtp_filename (str): Name of file with MTP to be loaded.
**kwargs: Different kwargs for mlip-2 and mlip-dev.
mlip-2:
mtp_filename (str): Name of file with MTP to be loaded.
write_cfgs (str): Name of file for mlp processed configurations to be written to.
write_cfgs_skip (int): Skipped number of processed configurations before writing.
select (bool): activates or deactivates calculation of extrapolation grades and
Expand Down Expand Up @@ -300,9 +322,9 @@ def write_ini(self, mtp_filename="fitted.mtp", select=False, **kwargs):
lattice vectors in Cartesian coordinates (in Angstroms).
Default to 1.0e-8.
BFGS_Wolfe_C1 (float): Wolfe condition constant on the function
decrease (linesearch stopping criterea). Default to 1.0e-3.
decrease (linesearch stopping criteria). Default to 1.0e-3.
BFGS_Wolfe_C2 (float): Wolfe condition constant on the gradient
decrease (linesearch stopping criterea). Default to 0.7.
decrease (linesearch stopping criteria). Default to 0.7.
Save_relaxed (str): Filename for output results of relaxation.
No configuration will be saved if not specified.
Default to None.
Expand All @@ -317,9 +339,9 @@ def write_ini(self, mtp_filename="fitted.mtp", select=False, **kwargs):
lines.append(format_str.format("write-cfgs", kwargs.get("write_cfgs")))
if kwargs.get("write_cfgs_skip"):
lines.append(format_str.format("write-cfgs:skip", kwargs.get("write_cfgs_skip")))
if select is False:
if not kwargs.get("select"):
lines.append(format_str.format("select", "FALSE"))
elif select is True:
else:
lines.append(format_str.format("select", "TRUE"))
select_identifiers = [
"select:save-selected",
Expand Down Expand Up @@ -413,15 +435,19 @@ def write_ini(self, mtp_filename="fitted.mtp", select=False, **kwargs):
if MLIP:
lines.append(
format_str.format(
MTini_params.get("MLIP").get("name"), "mtpr", MTini_params.get("MLIP").get("comment")
MTini_params.get("MLIP").get("name"),
"mtpr",
MTini_params.get("MLIP").get("comment"),
)
)
mlip = MTini_params.get("MLIP")
if kwargs.get("load_from"):
load_from = mlip.get("load_from")
lines.append(
format_str.format(
"\t" + load_from.get("name"), kwargs.get("load_from"), load_from.get("comment")
"\t" + load_from.get("name"),
kwargs.get("load_from"),
load_from.get("comment"),
)
)
if kwargs.get("Calculate_EFS"):
Expand All @@ -443,14 +469,18 @@ def write_ini(self, mtp_filename="fitted.mtp", select=False, **kwargs):
write_cfgs = mlip.get("Write_cfgs")
lines.append(
format_str.format(
"\t" + write_cfgs.get("name"), kwargs.get("Write_cfgs"), write_cfgs.get("comment")
"\t" + write_cfgs.get("name"),
kwargs.get("Write_cfgs"),
write_cfgs.get("comment"),
)
)

if Driver:
lines.append(
format_str.format(
MTini_params.get("Driver").get("name"), str(Driver), MTini_params.get("Driver").get("comment")
MTini_params.get("Driver").get("name"),
str(Driver),
MTini_params.get("Driver").get("comment"),
)
)
driver = MTini_params.get("Driver").get(str(Driver))
Expand Down Expand Up @@ -510,7 +540,12 @@ def formatify(string):
.tolist()
)
virial_stress = [virial_stress[self.mtp_stress_order.index(n)] for n in self.vasp_stress_order]
struct = Structure(lattice=lattice, species=species, coords=position, coords_are_cartesian=True)
struct = Structure(
lattice=lattice,
species=species,
coords=position,
coords_are_cartesian=True,
)
d["structure"] = struct.as_dict()
d["outputs"]["energy"] = energy
assert size == struct.num_sites
Expand Down Expand Up @@ -577,9 +612,11 @@ def train(
"Please refer to https://mlip.skoltech.ru",
"for further detail.",
)
train_structures, train_forces, train_stresses = check_structures_forces_stresses(
train_structures, train_forces, train_stresses
)
(
train_structures,
train_forces,
train_stresses,
) = check_structures_forces_stresses(train_structures, train_forces, train_stresses)
train_pool = pool_from(train_structures, train_energies, train_forces, train_stresses)
elements = sorted(set(itertools.chain(*[struct.species for struct in train_structures])))
self.elements = [str(element) for element in elements]
Expand Down Expand Up @@ -661,6 +698,7 @@ def write_param(self, fitted_mtp="fitted.mtp", **kwargs):
Args:
fitted_mtp (str): Filename to store xml formatted parameters.
**kwargs: pass to write_ini method.
"""
if not self.param:
raise RuntimeError("The parameters should be provided.")
Expand Down
Loading

0 comments on commit 2c1bf6a

Please sign in to comment.