diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index cefa4ac1..7dbff5f2 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -70,6 +70,8 @@ def __init__( "stress", "dipole", ] + elif model_type == "AtomicTargetMACE": + self.implemented_properties = ["atomic_target"] else: raise ValueError( f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported" @@ -146,6 +148,9 @@ def _create_result_tensors( if model_type in ["EnergyDipoleMACE", "DipoleMACE"]: dipole = torch.zeros(num_models, 3, device=self.device) dict_of_tensors.update({"dipole": dipole}) + if model_type == "AtomicTargetMACE": + atomic_target = torch.zeros(num_models, num_atoms, device=self.device) + dict_of_tensors.update({"atomic_target": atomic_target}) return dict_of_tensors # pylint: disable=dangerous-default-value @@ -195,6 +200,8 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): ret_tensors["stress"][i] = out["stress"].detach() if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]: ret_tensors["dipole"][i] = out["dipole"].detach() + if self.model_type == "AtomicTargetMACE": + ret_tensors["atomic_target"][i] = out["node_energy"].detach() self.results = {} if self.model_type in ["MACE", "EnergyDipoleMACE"]: @@ -250,6 +257,16 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): .cpu() .numpy() ) + if self.model_type == "AtomicTargetMACE": + self.results["atomic_target"] = ( + torch.mean(ret_tensors["atomic_target"], dim=0).cpu().numpy() + ) + if self.num_models > 1: + self.results["atomic_target_var"] = ( + torch.var(ret_tensors["atomic_target"], dim=0, unbiased=False) + .cpu() + .numpy() + ) def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): """Extracts the descriptors from MACE model. diff --git a/mace/data/__init__.py b/mace/data/__init__.py index 0d0c9bf2..491a98ee 100644 --- a/mace/data/__init__.py +++ b/mace/data/__init__.py @@ -9,6 +9,7 @@ load_from_xyz, random_train_valid_split, test_config_types, + compute_average_node_target, ) __all__ = [ @@ -22,4 +23,5 @@ "config_from_atoms_list", "AtomicData", "compute_average_E0s", + "compute_average_node_target", ] diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 31170d3d..ceeebdca 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -37,6 +37,7 @@ class AtomicData(torch_geometric.data.Data): virials: torch.Tensor dipole: torch.Tensor charges: torch.Tensor + node_target: torch.Tensor weight: torch.Tensor energy_weight: torch.Tensor forces_weight: torch.Tensor @@ -62,6 +63,7 @@ def __init__( virials: Optional[torch.Tensor], # [1,3,3] dipole: Optional[torch.Tensor], # [, 3] charges: Optional[torch.Tensor], # [n_nodes, ] + node_target: Optional[torch.Tensor], # [n_nodes, ] ): # Check shapes num_nodes = node_attrs.shape[0] @@ -83,6 +85,7 @@ def __init__( assert virials is None or virials.shape == (1, 3, 3) assert dipole is None or dipole.shape[-1] == 3 assert charges is None or charges.shape == (num_nodes,) + assert node_target is None or node_target.shape == (num_nodes,) # Aggregate data data = { "num_nodes": num_nodes, @@ -103,6 +106,7 @@ def __init__( "virials": virials, "dipole": dipole, "charges": charges, + "node_target": node_target, } super().__init__(**data) @@ -189,7 +193,11 @@ def from_config( if config.charges is not None else None ) - + node_target = ( + torch.tensor(config.node_target, dtype=torch.get_default_dtype()) + if config.node_target is not None + else None + ) return cls( edge_index=torch.tensor(edge_index, dtype=torch.long), positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()), @@ -208,6 +216,7 @@ def from_config( virials=virials, dipole=dipole, charges=charges, + node_target=node_target, ) diff --git a/mace/data/utils.py b/mace/data/utils.py index 9a701fc5..244fbcf7 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -37,6 +37,7 @@ class Configuration: virials: Optional[Virials] = None # eV dipole: Optional[Vector] = None # Debye charges: Optional[Charges] = None # atomic unit + node_target: Optional[Charges] = None cell: Optional[Cell] = None pbc: Optional[Pbc] = None @@ -77,6 +78,7 @@ def config_from_atoms_list( virials_key="virials", dipole_key="dipole", charges_key="charges", + atomic_target_key="hirshfeld_volumes", config_type_weights: Dict[str, float] = None, ) -> Configurations: """Convert list of ase.Atoms into Configurations""" @@ -94,6 +96,7 @@ def config_from_atoms_list( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + atomic_target_key=atomic_target_key, config_type_weights=config_type_weights, ) ) @@ -108,6 +111,7 @@ def config_from_atoms( virials_key="virials", dipole_key="dipole", charges_key="charges", + atomic_target_key="hirshfeld_volumes", config_type_weights: Dict[str, float] = None, ) -> Configuration: """Convert ase.Atoms to Configuration""" @@ -121,6 +125,7 @@ def config_from_atoms( dipole = atoms.info.get(dipole_key, None) # Debye # Charges default to 0 instead of None if not found charges = atoms.arrays.get(charges_key, np.zeros(len(atoms))) # atomic unit + atomic_target = atoms.arrays.get(atomic_target_key, None) atomic_numbers = np.array( [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] ) @@ -158,6 +163,7 @@ def config_from_atoms( virials=virials, dipole=dipole, charges=charges, + node_target=atomic_target, weight=weight, energy_weight=energy_weight, forces_weight=forces_weight, @@ -194,6 +200,7 @@ def load_from_xyz( virials_key: str = "virials", dipole_key: str = "dipole", charges_key: str = "charges", + atomic_target_key: str = 'hirshfeld_volumes', extract_atomic_energies: bool = False, ) -> Tuple[Dict[int, float], Configurations]: atoms_list = ase.io.read(file_path, index=":") @@ -239,6 +246,7 @@ def load_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + atomic_target_key=atomic_target_key, ) return atomic_energies_dict, configs @@ -271,3 +279,29 @@ def compute_average_E0s( for i, z in enumerate(z_table.zs): atomic_energies_dict[z] = 0.0 return atomic_energies_dict + +def compute_average_node_target( + collections_train: Configurations, z_table: AtomicNumberTable, +) -> Tuple[Dict[int, float], float]: + """ + Function to compute the average node target and node std of each chemical element + returns two dictionaries with average and scale + """ + len_train = len(collections_train) + len_zs = len(z_table) + elementwise_targets = {} + for i in range(len_train): + for j in range(len(collections_train[i].atomic_numbers)): + z = collections_train[i].atomic_numbers[j] + if z not in elementwise_targets.keys(): + elementwise_targets[z] = [] + elementwise_targets[z].append(collections_train[i].node_target[j]) + + atomic_energies_dict = {} + atomic_scales = [] + for i, z in enumerate(z_table.zs): + atomic_energies_dict[z] = np.mean(elementwise_targets[z]) + atomic_scales.append((len(elementwise_targets[z]), np.std(elementwise_targets[z]))) + # compute weighted average of scales with tuple element 0 ebing the weight and element 1 the value to average + scale = np.average([x[1] for x in atomic_scales], weights=[x[0] for x in atomic_scales]) + return atomic_energies_dict, scale \ No newline at end of file diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index d01a38ed..f83d5fa1 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -28,6 +28,7 @@ WeightedEnergyForcesVirialsLoss, WeightedForcesLoss, WeightedHuberEnergyForcesStressLoss, + PerNodesLoss, ) from .models import ( MACE, @@ -101,4 +102,5 @@ "compute_mean_std_atomic_inter_energy", "compute_avg_num_neighbors", "compute_fixed_charge_dipole", + "PerNodesLoss", ] diff --git a/mace/modules/loss.py b/mace/modules/loss.py index 7c8f5ad0..a636054a 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -77,6 +77,10 @@ def weighted_mean_squared_error_dipole(ref: Batch, pred: TensorDict) -> torch.Te # return torch.mean(torch.square((torch.reshape(ref['dipole'], pred["dipole"].shape) - pred['dipole']) / num_atoms)) # [] +def mean_squared_error_per_atom(ref: Batch, pred: TensorDict) -> torch.Tensor: + # energy: [n_graphs, ] + return torch.mean(torch.square(ref["node_target"] - pred["node_energy"])) # [] + class WeightedEnergyForcesLoss(torch.nn.Module): def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: super().__init__() @@ -258,3 +262,17 @@ def __repr__(self): f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})" ) + + +class PerNodesLoss(torch.nn.Module): + def __init__(self, loss_scale=100) -> None: + super().__init__() + self.loss_scale = loss_scale + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return self.loss_scale * mean_squared_error_per_atom(ref, pred) + + def __repr__(self): + return ( + f"{self.__class__.__name__}" + ) \ No newline at end of file diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 192306c3..c89a4b1e 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -271,6 +271,18 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=str, default="charges", ) + parser.add_argument( + "--atomic_target_key", + help="Key of atomic target in training xyz", + type=str, + default=None, + ) + parser.add_argument( + "--atomic_target_loss_scale", + help="Prefactor of loss function for atomic target", + type=float, + default=100, + ) # Loss and optimization parser.add_argument( @@ -286,6 +298,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "dipole", "huber", "energy_forces_dipole", + "atomic_target_loss", ], ) parser.add_argument( diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 41ea7f77..af2074f0 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -36,6 +36,7 @@ def get_dataset_from_xyz( virials_key: str = "virials", dipole_key: str = "dipoles", charges_key: str = "charges", + atomic_target_key: str = 'hirshfeld_volumes', ) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: """Load training and test dataset from xyz file""" atomic_energies_dict, all_train_configs = data.load_from_xyz( @@ -47,6 +48,7 @@ def get_dataset_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + atomic_target_key=atomic_target_key, extract_atomic_energies=True, ) logging.info( @@ -62,6 +64,7 @@ def get_dataset_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + atomic_target_key=atomic_target_key, extract_atomic_energies=False, ) logging.info( @@ -85,6 +88,7 @@ def get_dataset_from_xyz( forces_key=forces_key, dipole_key=dipole_key, charges_key=charges_key, + atomic_target_key=atomic_target_key, extract_atomic_energies=False, ) # create list of tuples (config_type, list(Atoms)) @@ -198,6 +202,12 @@ def create_error_table( "RMSE MU / mDebye / atom", "rel MU RMSE %", ] + elif table_type == "PerAtomTargetRMSE": + table.field_names = [ + "config_type", + "RMSE Target / atom", + "rel Target RMSE %", + ] for name, subset in all_collections: data_loader = torch_geometric.dataloader.DataLoader( dataset=[ @@ -315,4 +325,12 @@ def create_error_table( f"{metrics['rel_rmse_mu']:.1f}", ] ) + elif table_type == "PerAtomTargetRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_per_nodes_target']:.1f}", + f"{metrics['rel_rmse_nodes_target']:.1f}", + ] + ) return table diff --git a/mace/tools/train.py b/mace/tools/train.py index 545f077e..6fc912d1 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -181,6 +181,11 @@ def train( logging.info( f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye" ) + elif log_errors == "PerAtomTargetRMSE": + error_target = eval_metrics["rmse_per_nodes_target"] * 1e3 + logging.info( + f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_target={error_target:.1f}" + ) if log_wandb: wandb_log_dict = { "epoch": epoch, @@ -274,6 +279,9 @@ def evaluate( E_computed = False delta_es_list = [] delta_es_per_atom_list = [] + per_node_targets_list = [] + delta_per_nodes_target_list = [] + per_target_output_computed = False delta_fs_list = [] Fs_computed = False fs_list = [] @@ -307,11 +315,19 @@ def evaluate( total_loss += to_numpy(loss).item() if output.get("energy") is not None and batch.energy is not None: - E_computed = True - delta_es_list.append(batch.energy - output["energy"]) - delta_es_per_atom_list.append( - (batch.energy - output["energy"]) / (batch.ptr[1:] - batch.ptr[:-1]) - ) + if output_args["per_node_output"] is False: + E_computed = True + delta_es_list.append(batch.energy - output["energy"]) + delta_es_per_atom_list.append( + (batch.energy - output["energy"]) / (batch.ptr[1:] - batch.ptr[:-1]) + ) + else: + E_computed = False + per_target_output_computed = True + delta_per_nodes_target_list.append( + batch.node_target - output["node_energy"] + ) + per_node_targets_list.append(batch.node_target) if output.get("forces") is not None and batch.forces is not None: Fs_computed = True delta_fs_list.append(batch.forces - output["forces"]) @@ -353,6 +369,20 @@ def evaluate( aux["rmse_e"] = compute_rmse(delta_es) aux["rmse_e_per_atom"] = compute_rmse(delta_es_per_atom) aux["q95_e"] = compute_q95(delta_es) + if per_target_output_computed: + delta_per_nodes_target = to_numpy( + torch.cat(delta_per_nodes_target_list, dim=0) + ) + per_node_targets = to_numpy(torch.cat(per_node_targets_list, dim=0)) + aux["mae_per_nodes_target"] = compute_mae(delta_per_nodes_target) + aux["rmse_per_nodes_target"] = compute_rmse(delta_per_nodes_target) + aux["q95_per_nodes_target"] = compute_q95(delta_per_nodes_target) + aux["rel_mae_nodes_target"] = compute_rel_mae( + delta_per_nodes_target, per_node_targets + ) + aux["rel_rmse_nodes_target"] = compute_rel_rmse( + delta_per_nodes_target, per_node_targets + ) if Fs_computed: delta_fs = to_numpy(torch.cat(delta_fs_list, dim=0)) fs = to_numpy(torch.cat(fs_list, dim=0)) diff --git a/scripts/run_train.py b/scripts/run_train.py index 6e1f9c2d..7a5bbf51 100644 --- a/scripts/run_train.py +++ b/scripts/run_train.py @@ -50,6 +50,7 @@ def main() -> None: ) config_type_weights = {"Default": 1.0} + # Data preparation collections, atomic_energies_dict = get_dataset_from_xyz( train_path=args.train_file, @@ -64,6 +65,7 @@ def main() -> None: virials_key=args.virials_key, dipole_key=args.dipole_key, charges_key=args.charges_key, + atomic_target_key=args.atomic_target_key, ) logging.info( @@ -71,6 +73,18 @@ def main() -> None: f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]" ) + if args.atomic_target_key is not None: + logging.info(f"Training atomic target model on {args.atomic_target_key}") + args.error_table = "PerAtomTargetRMSE" + compute_per_node_target = True + args.E0s = None + args.compute_forces = False # no forces for atomic target pred + args.compute_stress = False + args.loss = "atomic_target_loss" + atomic_energies_dict = None + else: + compute_per_node_target = False + # Atomic number table # yapf: disable z_table = tools.get_atomic_number_table_from_zs( @@ -121,9 +135,17 @@ def main() -> None: f"E0s specified invalidly, error {e} occured" ) from e else: - raise RuntimeError( - "E0s not found in training file and not specified in command line" - ) + if compute_per_node_target: + logging.info( + "Computing average Atomic Target for scaling" + ) + atomic_energies_dict, atomic_scale = data.compute_average_node_target( + collections.train, z_table + ) + else: + raise RuntimeError( + "E0s not found in training file and not specified in command line" + ) atomic_energies: np.ndarray = np.array( [atomic_energies_dict[z] for z in z_table.zs] ) @@ -188,6 +210,8 @@ def main() -> None: forces_weight=args.forces_weight, dipole_weight=args.dipole_weight, ) + elif args.loss == "atomic_target_loss": + loss_fn = modules.PerNodesLoss(args.atomic_target_loss_scale) else: # Unweighted Energy and Forces loss by default loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) @@ -210,6 +234,7 @@ def main() -> None: "virials": compute_virials, "stress": args.compute_stress, "dipoles": compute_dipole, + "per_node_output": compute_per_node_target, } logging.info(f"Selected the following outputs: {output_args}") @@ -249,6 +274,9 @@ def main() -> None: if args.scaling == "no_scaling": std = 1.0 logging.info("No scaling selected") + elif compute_per_node_target: + mean = 0.0 + std = atomic_scale else: mean, std = modules.scaling_classes[args.scaling]( train_loader, atomic_energies @@ -418,6 +446,8 @@ def main() -> None: logging.info( f"Using stochastic weight averaging (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" ) + elif args.loss == "atomic_target_loss": + loss_fn_energy = modules.PerNodesLoss(args.atomic_target_loss_scale) else: loss_fn_energy = modules.WeightedEnergyForcesLoss( energy_weight=args.swa_energy_weight, diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 7598ef28..541ed061 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -39,6 +39,9 @@ def fixture_fitting_configs(): c.info["REF_stress"] = np.random.normal(0.1, size=6) fit_configs.append(c) + for c in fit_configs: + c.new_array("REF_atomic_target", np.random.normal(0.1, size=len(c))) + return fit_configs @@ -274,3 +277,37 @@ def test_run_train_no_stress(tmp_path, fitting_configs): -0.02801561982433547, ] assert np.allclose(Es, ref_Es) + +def test_run_train_node_target(tmp_path, fitting_configs): + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["atomic_target_key"] = "REF_atomic_target" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(tmp_path / "MACE.model", device="cpu") +