
Encode molecules starting from their SMILES #26

Open
marcosbodio opened this issue Sep 9, 2024 · 7 comments

@marcosbodio

Hello, I would like to know if it is possible to use GraphMVP to encode molecules starting from their SMILES. I have read this issue, but it does not help much. I would be really grateful if you could provide some explanation, and ideally an example. Thank you!

@chao1224
Owner

Hi @marcosbodio,

Thank you for your question.

  • SMILES is a string representation of a molecule's topology, and it is not exactly the same as the 2D graph.
  • So if you want to reuse the implementation in the current repo as-is, the answer is no.
  • However, if we expand GraphMVP from 2D-3D to topology-geometry, then the answer is yes. What you need to do is replace the 2D graph + 2D GNN (GIN in our paper) with SMILES + BERT (or any other sequence encoder); see the sketch after this list.
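
A minimal sketch of that topology-geometry variant, assuming a generic BERT-style SMILES encoder loaded via Hugging Face transformers (the checkpoint name below is only a placeholder, not something shipped with this repo):

import torch
from transformers import AutoModel, AutoTokenizer

# Placeholder name: substitute any BERT-style encoder pretrained on SMILES.
tokenizer = AutoTokenizer.from_pretrained('some-smiles-bert')
encoder = AutoModel.from_pretrained('some-smiles-bert')
encoder.eval()

smiles = 'Cn1cnc(c1)C(=O)c1ccc(CN2[C@H](Cc3ccccn3)C(=O)Nc3cc(Cl)ccc3C2=O)cc1'
tokens = tokenizer(smiles, return_tensors='pt')
with torch.no_grad():
    hidden = encoder(**tokens).last_hidden_state  # (1, seq_len, hidden_dim)
embedding = hidden.mean(dim=1)  # mean-pool token states into one molecule vector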

@marcosbodio
Author

Hi @chao1224, thank you for your answer. I see that Table 5 in your paper lists results on DTA tasks with Davis and KIBA. These datasets contain SMILES of molecules, so how did you use GraphMVP (or GraphMVP-G, GraphMVP-C) on them? It would be very useful to see the code, because that would clarify the proper way of using your model starting from the SMILES of a molecule.

@chao1224
Owner

Hi @marcosbodio,

Sure, you can check this Python script; specifically, this line assigns which dataset to use.

@marcosbodio
Author

Hi @chao1224, I have looked at the script that you linked above, and I think it is for fine-tuning your model, which I would prefer to avoid.

I was hoping to use a checkpoint of your model, for example output/3D_hybrid_02_masking/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.3_EBM_dot_prod_0.1_normalize_l2_detach_target_2_100_0/pretraining_model.pth from GraphMVP_simple_features_for_classification.zip (shared here).

I wonder if I could do something like this:

import torch
from rdkit import Chem
from rdkit.Chem.rdDistGeom import EmbedMolecule

from src_classification.GEOM_dataset_preparation import mol_to_graph_data_obj_simple_3D

smiles = 'Cn1cnc(c1)C(=O)c1ccc(CN2[C@H](Cc3ccccn3)C(=O)Nc3cc(Cl)ccc3C2=O)cc1'
mol = Chem.MolFromSmiles(smiles)
mol = Chem.AddHs(mol)  # explicit hydrogens give a more sensible 3D embedding
assert EmbedMolecule(mol=mol) != -1  # EmbedMolecule returns -1 if embedding fails
data = mol_to_graph_data_obj_simple_3D(mol)

and then feed data to the model loaded from the checkpoint to compute an embedding of the SMILES. What do you think?

@chao1224
Owner

chao1224 commented Sep 20, 2024

Hi @marcosbodio,

Yes, I think this is right if you want to use the 3D representation.

  1. When we create the checkpoints, we save the following modules (code):
            saver_dict = {
                'model': molecule_model_2D.state_dict(),
                'model_3D': molecule_model_3D.state_dict(),
                'AE_2D_3D_model': AE_2D_3D_model.state_dict(),
                'AE_3D_2D_model': AE_3D_2D_model.state_dict(),
            }
  2. What you wrote above can be fed into model_3D.
  3. If you only want to use the 2D checkpoint, which is model above, then you can follow this pseudocode:
from rdkit import Chem
# mol_to_graph_data_obj_simple comes from this repo; see the function linked below.

smiles = 'Cn1cnc(c1)C(=O)c1ccc(CN2[C@H](Cc3ccccn3)C(=O)Nc3cc(Cl)ccc3C2=O)cc1'
mol = Chem.MolFromSmiles(smiles)

data = mol_to_graph_data_obj_simple(mol)  # 2D graph only, no conformer needed

where mol_to_graph_data_obj_simple is defined in this function.
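
Here is a hedged sketch of the 2D loading step, assuming the checkpoint follows the saver_dict layout above and that the repo's GIN encoder is exposed as a GNN class with the usual pretraining hyperparameters (5 layers, emb_dim 300); the module path, class name, and arguments below are assumptions to check against the actual code:

import torch
from torch_geometric.data import Batch
from torch_geometric.nn import global_mean_pool

from src_classification.models import GNN  # assumed module path and class name

model_2D = GNN(num_layer=5, emb_dim=300, JK='last', drop_ratio=0.0, gnn_type='gin')
checkpoint = torch.load('pretraining_model.pth', map_location='cpu')
model_2D.load_state_dict(checkpoint['model'])  # 'model' holds the 2D GNN weights
model_2D.eval()

batch = Batch.from_data_list([data])  # wrap the single molecule graph in a batch
with torch.no_grad():
    node_repr = model_2D(batch.x, batch.edge_index, batch.edge_attr)
embedding = global_mean_pool(node_repr, batch.batch)  # one vector per molecule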

@marcosbodio
Author

Hi @chao1224,

I have tried to load one of your model checkpoints, but I do not see model_3D. Here is what I did:

model_path = 'output/3D_hybrid_02_masking/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.3_EBM_dot_prod_0.1_normalize_l2_detach_target_2_100_0/pretraining_model.pth'
model = torch.load(f=model_path, map_location=torch.device('cpu'))
print(model.keys())  # expecting saver_dict keys like 'model', 'model_3D', ...
print('model_3D' in model)

where model_path comes from your file GraphMVP_simple_features_for_classification.zip (shared here).

The previous code prints the following:

odict_keys(['x_embedding1.weight', 'x_embedding2.weight', 'gnns.0.mlp.0.weight', 'gnns.0.mlp.0.bias', 'gnns.0.mlp.2.weight', 'gnns.0.mlp.2.bias', 'gnns.0.edge_embedding1.weight', 'gnns.0.edge_embedding2.weight', 'gnns.1.mlp.0.weight', 'gnns.1.mlp.0.bias', 'gnns.1.mlp.2.weight', 'gnns.1.mlp.2.bias', 'gnns.1.edge_embedding1.weight', 'gnns.1.edge_embedding2.weight', 'gnns.2.mlp.0.weight', 'gnns.2.mlp.0.bias', 'gnns.2.mlp.2.weight', 'gnns.2.mlp.2.bias', 'gnns.2.edge_embedding1.weight', 'gnns.2.edge_embedding2.weight', 'gnns.3.mlp.0.weight', 'gnns.3.mlp.0.bias', 'gnns.3.mlp.2.weight', 'gnns.3.mlp.2.bias', 'gnns.3.edge_embedding1.weight', 'gnns.3.edge_embedding2.weight', 'gnns.4.mlp.0.weight', 'gnns.4.mlp.0.bias', 'gnns.4.mlp.2.weight', 'gnns.4.mlp.2.bias', 'gnns.4.edge_embedding1.weight', 'gnns.4.edge_embedding2.weight', 'batch_norms.0.weight', 'batch_norms.0.bias', 'batch_norms.0.running_mean', 'batch_norms.0.running_var', 'batch_norms.0.num_batches_tracked', 'batch_norms.1.weight', 'batch_norms.1.bias', 'batch_norms.1.running_mean', 'batch_norms.1.running_var', 'batch_norms.1.num_batches_tracked', 'batch_norms.2.weight', 'batch_norms.2.bias', 'batch_norms.2.running_mean', 'batch_norms.2.running_var', 'batch_norms.2.num_batches_tracked', 'batch_norms.3.weight', 'batch_norms.3.bias', 'batch_norms.3.running_mean', 'batch_norms.3.running_var', 'batch_norms.3.num_batches_tracked', 'batch_norms.4.weight', 'batch_norms.4.bias', 'batch_norms.4.running_mean', 'batch_norms.4.running_var', 'batch_norms.4.num_batches_tracked'])
False

Am I loading the wrong checkpoint?
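
Or, since the printed keys look like a bare 5-layer GIN state dict (there is no 'model' or 'model_3D' wrapper), could I load it directly into the 2D encoder? A sketch, assuming the same hypothetical GNN class as above:

model_2D = GNN(num_layer=5, emb_dim=300, JK='last', drop_ratio=0.0, gnn_type='gin')
model_2D.load_state_dict(model)  # `model` here is the flat state dict loaded above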

@chao1224
Owner

Hi @marcosbodio,

I need to double-check the checkpoint files when I get time. Meanwhile, you should be able to use this checkpoint, which is one of the SOTA PaiNN pretraining methods (paper link).
