Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement per node target fitting #173

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def __init__(
"stress",
"dipole",
]
elif model_type == "AtomicTargetMACE":
self.implemented_properties = ["atomic_target"]
else:
raise ValueError(
f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported"
Expand Down Expand Up @@ -146,6 +148,9 @@ def _create_result_tensors(
if model_type in ["EnergyDipoleMACE", "DipoleMACE"]:
dipole = torch.zeros(num_models, 3, device=self.device)
dict_of_tensors.update({"dipole": dipole})
if model_type == "AtomicTargetMACE":
atomic_target = torch.zeros(num_models, num_atoms, device=self.device)
dict_of_tensors.update({"atomic_target": atomic_target})
return dict_of_tensors

# pylint: disable=dangerous-default-value
Expand Down Expand Up @@ -195,6 +200,8 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
ret_tensors["stress"][i] = out["stress"].detach()
if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]:
ret_tensors["dipole"][i] = out["dipole"].detach()
if self.model_type == "AtomicTargetMACE":
ret_tensors["atomic_target"][i] = out["node_energy"].detach()

self.results = {}
if self.model_type in ["MACE", "EnergyDipoleMACE"]:
Expand Down Expand Up @@ -250,6 +257,16 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
.cpu()
.numpy()
)
if self.model_type == "AtomicTargetMACE":
self.results["atomic_target"] = (
torch.mean(ret_tensors["atomic_target"], dim=0).cpu().numpy()
)
if self.num_models > 1:
self.results["atomic_target_var"] = (
torch.var(ret_tensors["atomic_target"], dim=0, unbiased=False)
.cpu()
.numpy()
)

def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
"""Extracts the descriptors from MACE model.
Expand Down
2 changes: 2 additions & 0 deletions mace/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
load_from_xyz,
random_train_valid_split,
test_config_types,
compute_average_node_target,
)

__all__ = [
Expand All @@ -22,4 +23,5 @@
"config_from_atoms_list",
"AtomicData",
"compute_average_E0s",
"compute_average_node_target",
]
11 changes: 10 additions & 1 deletion mace/data/atomic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class AtomicData(torch_geometric.data.Data):
virials: torch.Tensor
dipole: torch.Tensor
charges: torch.Tensor
node_target: torch.Tensor
weight: torch.Tensor
energy_weight: torch.Tensor
forces_weight: torch.Tensor
Expand All @@ -62,6 +63,7 @@ def __init__(
virials: Optional[torch.Tensor], # [1,3,3]
dipole: Optional[torch.Tensor], # [, 3]
charges: Optional[torch.Tensor], # [n_nodes, ]
node_target: Optional[torch.Tensor], # [n_nodes, ]
):
# Check shapes
num_nodes = node_attrs.shape[0]
Expand All @@ -83,6 +85,7 @@ def __init__(
assert virials is None or virials.shape == (1, 3, 3)
assert dipole is None or dipole.shape[-1] == 3
assert charges is None or charges.shape == (num_nodes,)
assert node_target is None or node_target.shape == (num_nodes,)
# Aggregate data
data = {
"num_nodes": num_nodes,
Expand All @@ -103,6 +106,7 @@ def __init__(
"virials": virials,
"dipole": dipole,
"charges": charges,
"node_target": node_target,
}
super().__init__(**data)

Expand Down Expand Up @@ -189,7 +193,11 @@ def from_config(
if config.charges is not None
else None
)

node_target = (
torch.tensor(config.node_target, dtype=torch.get_default_dtype())
if config.node_target is not None
else None
)
return cls(
edge_index=torch.tensor(edge_index, dtype=torch.long),
positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()),
Expand All @@ -208,6 +216,7 @@ def from_config(
virials=virials,
dipole=dipole,
charges=charges,
node_target=node_target,
)


Expand Down
34 changes: 34 additions & 0 deletions mace/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Configuration:
virials: Optional[Virials] = None # eV
dipole: Optional[Vector] = None # Debye
charges: Optional[Charges] = None # atomic unit
node_target: Optional[Charges] = None
cell: Optional[Cell] = None
pbc: Optional[Pbc] = None

Expand Down Expand Up @@ -77,6 +78,7 @@ def config_from_atoms_list(
virials_key="virials",
dipole_key="dipole",
charges_key="charges",
atomic_target_key="hirshfeld_volumes",
config_type_weights: Dict[str, float] = None,
) -> Configurations:
"""Convert list of ase.Atoms into Configurations"""
Expand All @@ -94,6 +96,7 @@ def config_from_atoms_list(
virials_key=virials_key,
dipole_key=dipole_key,
charges_key=charges_key,
atomic_target_key=atomic_target_key,
config_type_weights=config_type_weights,
)
)
Expand All @@ -108,6 +111,7 @@ def config_from_atoms(
virials_key="virials",
dipole_key="dipole",
charges_key="charges",
atomic_target_key="hirshfeld_volumes",
config_type_weights: Dict[str, float] = None,
) -> Configuration:
"""Convert ase.Atoms to Configuration"""
Expand All @@ -121,6 +125,7 @@ def config_from_atoms(
dipole = atoms.info.get(dipole_key, None) # Debye
# Charges default to 0 instead of None if not found
charges = atoms.arrays.get(charges_key, np.zeros(len(atoms))) # atomic unit
atomic_target = atoms.arrays.get(atomic_target_key, None)
atomic_numbers = np.array(
[ase.data.atomic_numbers[symbol] for symbol in atoms.symbols]
)
Expand Down Expand Up @@ -158,6 +163,7 @@ def config_from_atoms(
virials=virials,
dipole=dipole,
charges=charges,
node_target=atomic_target,
weight=weight,
energy_weight=energy_weight,
forces_weight=forces_weight,
Expand Down Expand Up @@ -194,6 +200,7 @@ def load_from_xyz(
virials_key: str = "virials",
dipole_key: str = "dipole",
charges_key: str = "charges",
atomic_target_key: str = 'hirshfeld_volumes',
extract_atomic_energies: bool = False,
) -> Tuple[Dict[int, float], Configurations]:
atoms_list = ase.io.read(file_path, index=":")
Expand Down Expand Up @@ -239,6 +246,7 @@ def load_from_xyz(
virials_key=virials_key,
dipole_key=dipole_key,
charges_key=charges_key,
atomic_target_key=atomic_target_key,
)
return atomic_energies_dict, configs

Expand Down Expand Up @@ -271,3 +279,29 @@ def compute_average_E0s(
for i, z in enumerate(z_table.zs):
atomic_energies_dict[z] = 0.0
return atomic_energies_dict

def compute_average_node_target(
    collections_train: Configurations,
    z_table: AtomicNumberTable,
) -> Tuple[Dict[int, float], float]:
    """Compute the per-element mean node target and one global scale.

    Every atom of every training configuration contributes its
    ``node_target`` value to the sample of its chemical element.

    Args:
        collections_train: training configurations; each one must expose
            ``atomic_numbers`` and a ``node_target`` of the same length.
        z_table: table whose ``zs`` attribute lists the atomic numbers for
            which a mean is returned.

    Returns:
        A tuple ``(average_targets, scale)``. ``average_targets`` maps each
        atomic number in ``z_table.zs`` to the mean node target of that
        element. ``scale`` is the average of the per-element standard
        deviations, weighted by the number of atoms of each element.

    Raises:
        ValueError: if a configuration has no node target, if its node
            target length differs from its number of atoms, or if no
            element of ``z_table`` has any training data.
    """
    # Function-scope import so the block does not rely on a module-level one.
    import logging

    targets_by_element = {}
    for index, config in enumerate(collections_train):
        node_target = config.node_target
        if node_target is None:
            raise ValueError(
                f"Training configuration {index} has no node target; "
                "cannot compute average node targets"
            )
        if len(node_target) != len(config.atomic_numbers):
            raise ValueError(
                f"Training configuration {index} has {len(node_target)} node "
                f"targets for {len(config.atomic_numbers)} atoms"
            )
        for z, target in zip(config.atomic_numbers, node_target):
            targets_by_element.setdefault(int(z), []).append(target)

    average_targets = {}
    counts = []  # number of atoms per element, used as averaging weights
    stds = []  # per-element standard deviation of the node target
    for z in z_table.zs:
        targets = targets_by_element.get(z)
        if not targets:
            # An element can be listed in z_table without appearing in the
            # training set; fall back to a zero mean and leave it out of
            # the scale instead of failing with a bare KeyError.
            logging.warning(
                "No node targets found for element %s in the training set, "
                "using 0.0 as its average",
                z,
            )
            average_targets[z] = 0.0
            continue
        average_targets[z] = float(np.mean(targets))
        counts.append(len(targets))
        stds.append(float(np.std(targets)))

    if not counts:
        raise ValueError(
            "No node targets found for any element of z_table in the training set"
        )
    # Weighted average of the per-element standard deviations; the weight of
    # an element is the number of atoms of that element.
    scale = float(np.average(stds, weights=counts))
    return average_targets, scale
2 changes: 2 additions & 0 deletions mace/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
WeightedEnergyForcesVirialsLoss,
WeightedForcesLoss,
WeightedHuberEnergyForcesStressLoss,
PerNodesLoss,
)
from .models import (
MACE,
Expand Down Expand Up @@ -101,4 +102,5 @@
"compute_mean_std_atomic_inter_energy",
"compute_avg_num_neighbors",
"compute_fixed_charge_dipole",
"PerNodesLoss",
]
18 changes: 18 additions & 0 deletions mace/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def weighted_mean_squared_error_dipole(ref: Batch, pred: TensorDict) -> torch.Te
# return torch.mean(torch.square((torch.reshape(ref['dipole'], pred["dipole"].shape) - pred['dipole']) / num_atoms)) # []


def mean_squared_error_per_atom(ref: Batch, pred: TensorDict) -> torch.Tensor:
    """Return the mean squared error between reference and predicted node values.

    Reads ``ref["node_target"]`` and ``pred["node_energy"]``, subtracts them
    element-wise and averages the squared differences over all elements.
    """
    # NOTE(review): both tensors are presumably [n_nodes, ] - confirm against
    # the model output; tensors of different rank would broadcast silently.
    residual = ref["node_target"] - pred["node_energy"]
    return (residual * residual).mean()  # []

class WeightedEnergyForcesLoss(torch.nn.Module):
def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None:
super().__init__()
Expand Down Expand Up @@ -258,3 +262,17 @@ def __repr__(self):
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})"
)


class PerNodesLoss(torch.nn.Module):
    """Scaled mean squared error loss on a per-node target.

    Args:
        loss_scale: constant prefactor multiplied with the value returned by
            ``mean_squared_error_per_atom``.
    """

    def __init__(self, loss_scale=100) -> None:
        super().__init__()
        self.loss_scale = loss_scale

    def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor:
        """Return ``loss_scale`` times the per-node mean squared error."""
        return self.loss_scale * mean_squared_error_per_atom(ref, pred)

    def __repr__(self):
        # Show the configured prefactor so that logged loss objects can be
        # told apart, as the other loss classes do with their weights.
        return f"{self.__class__.__name__}(loss_scale={self.loss_scale:.3f})"
13 changes: 13 additions & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,18 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=str,
default="charges",
)
parser.add_argument(
"--atomic_target_key",
help="Key of atomic target in training xyz",
type=str,
default=None,
)
parser.add_argument(
"--atomic_target_loss_scale",
help="Prefactor of loss function for atomic target",
type=float,
default=100,
)

# Loss and optimization
parser.add_argument(
Expand All @@ -286,6 +298,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"dipole",
"huber",
"energy_forces_dipole",
"atomic_target_loss",
],
)
parser.add_argument(
Expand Down
18 changes: 18 additions & 0 deletions mace/tools/scripts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_dataset_from_xyz(
virials_key: str = "virials",
dipole_key: str = "dipoles",
charges_key: str = "charges",
atomic_target_key: str = 'hirshfeld_volumes',
) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]:
"""Load training and test dataset from xyz file"""
atomic_energies_dict, all_train_configs = data.load_from_xyz(
Expand All @@ -47,6 +48,7 @@ def get_dataset_from_xyz(
virials_key=virials_key,
dipole_key=dipole_key,
charges_key=charges_key,
atomic_target_key=atomic_target_key,
extract_atomic_energies=True,
)
logging.info(
Expand All @@ -62,6 +64,7 @@ def get_dataset_from_xyz(
virials_key=virials_key,
dipole_key=dipole_key,
charges_key=charges_key,
atomic_target_key=atomic_target_key,
extract_atomic_energies=False,
)
logging.info(
Expand All @@ -85,6 +88,7 @@ def get_dataset_from_xyz(
forces_key=forces_key,
dipole_key=dipole_key,
charges_key=charges_key,
atomic_target_key=atomic_target_key,
extract_atomic_energies=False,
)
# create list of tuples (config_type, list(Atoms))
Expand Down Expand Up @@ -198,6 +202,12 @@ def create_error_table(
"RMSE MU / mDebye / atom",
"rel MU RMSE %",
]
elif table_type == "PerAtomTargetRMSE":
table.field_names = [
"config_type",
"RMSE Target / atom",
"rel Target RMSE %",
]
for name, subset in all_collections:
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
Expand Down Expand Up @@ -315,4 +325,12 @@ def create_error_table(
f"{metrics['rel_rmse_mu']:.1f}",
]
)
elif table_type == "PerAtomTargetRMSE":
table.add_row(
[
name,
f"{metrics['rmse_per_nodes_target']:.1f}",
f"{metrics['rel_rmse_nodes_target']:.1f}",
]
)
return table
Loading