Quick script to obtain molecular embeddings #27

Open
fengjiang02 opened this issue Dec 16, 2024 · 3 comments

@fengjiang02

Is there a shortcut script that inputs a molecule's SMILES or its 2D/3D structure and outputs the embedding of that molecule? Thank you.

@chao1224
Owner

chao1224 commented Jan 6, 2025

Hi @fengjiang02,

I can provide some Python pseudocode, which should be enough for you to get the embedding.

SMILES / 2D Graph

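import torch
from rdkit import Chem
from torch_geometric.data import Data, Batch
from src_classification.datasets.molecule_datasets import mol_to_graph_data_obj_simple
# assumed import path for the model classes; adjust it to your checkout
from src_classification.models import GNN, GNN_graphpred

SMILES = "..."
molecule = Chem.MolFromSmiles(SMILES)  # returns None if the SMILES cannot be parsed
pyg_graph_data = mol_to_graph_data_obj_simple(molecule)
batch_data = Batch.from_data_list([pyg_graph_data]).to(device)

# set up model (args, num_tasks, and device come from your own configuration)
molecule_model = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim,
                     JK=args.JK, drop_ratio=args.dropout_ratio,
                     gnn_type=args.gnn_type)
model = GNN_graphpred(args=args, num_tasks=num_tasks,
                      molecule_model=molecule_model).to(device)
model.from_pretrained("checkpoint_path")
model.eval()  # disable dropout for inference

with torch.no_grad():
    representation_2D = model(batch_data)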
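
To make the batching step above more concrete, here is a minimal pure-Python sketch of what collating several graphs and pooling node embeddings into one vector per molecule amounts to. It is a toy illustration, not the repository's code: the helper names `collate` and `mean_pool` are hypothetical, the node embeddings are made up, and mean pooling is assumed as the readout (the pooling configured in your run may differ).

```python
# Toy illustration of graph batching plus a mean-pool readout.
# `collate` and `mean_pool` are hypothetical helpers, not repository functions.

def collate(graphs):
    """Stack per-node rows and record which graph each node belongs to."""
    node_features, batch_index = [], []
    for graph_id, nodes in enumerate(graphs):
        for row in nodes:
            node_features.append(list(row))
            batch_index.append(graph_id)
    return node_features, batch_index


def mean_pool(node_features, batch_index, num_graphs):
    """Average the node rows of each graph: one embedding vector per molecule."""
    dim = len(node_features[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for row, graph_id in zip(node_features, batch_index):
        counts[graph_id] += 1
        for k, value in enumerate(row):
            sums[graph_id][k] += value
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]


# Two made-up molecules with 2-dimensional node embeddings.
mol_a = [[1.0, 0.0], [3.0, 2.0]]              # 2 atoms
mol_b = [[0.0, 4.0], [0.0, 2.0], [3.0, 0.0]]  # 3 atoms

features, batch_index = collate([mol_a, mol_b])
embeddings = mean_pool(features, batch_index, num_graphs=2)
print(batch_index)  # node-to-graph assignment
print(embeddings)   # one vector per molecule
```

The `batch_index` list plays the role of the `batch` vector that `Batch.from_data_list` attaches to the collated graph, which is how a readout knows which nodes to average together.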

3D Graph

# suppose you have the 3D information (atom features x and coordinates positions)
# extracted from the RDKit molecule object
molecule_3D = ...
pyg_graph_data = Data(x=molecule_3D.x, positions=molecule_3D.positions)
# out_customized_batch is a placeholder for a batch class that provides
# from_data_list (e.g. Batch above)
batch_data = out_customized_batch.from_data_list([pyg_graph_data])

model = SchNet(
    hidden_channels=args.emb_dim, num_filters=args.num_filters, num_interactions=args.num_interactions,
    num_gaussians=args.num_gaussians, cutoff=args.cutoff, atomref=None, readout=args.readout).to(device)
model.from_pretrained("checkpoint_path")
model.eval()  # inference mode

representation_3D = model(batch_data)
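
If your 3D structure lives in a plain XYZ file rather than an RDKit object, the atom types and coordinates for the `x` and `positions` fields above could be extracted roughly as follows. This is a sketch under assumptions: `parse_xyz` and the small `ATOMIC_NUMBERS` table are hypothetical helpers written for this example, the water geometry is only illustrative, and the exact feature layout the pretrained model expects should be checked against the repository's dataset code.

```python
# Hypothetical helper: read XYZ-format text into atomic numbers and coordinates.
# The element table is deliberately small; extend it for your molecules.
ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8, "F": 9,
                  "P": 15, "S": 16, "Cl": 17, "Br": 35}


def parse_xyz(text):
    """Return (atomic_numbers, positions) parsed from an XYZ-format string."""
    lines = text.strip().splitlines()
    num_atoms = int(lines[0])            # first line: atom count
    atom_lines = lines[2:2 + num_atoms]  # second line is a free-form comment
    if len(atom_lines) != num_atoms:
        raise ValueError("XYZ block is shorter than its declared atom count")
    atomic_numbers, positions = [], []
    for line in atom_lines:
        symbol, x, y, z = line.split()[:4]
        atomic_numbers.append(ATOMIC_NUMBERS[symbol])
        positions.append([float(x), float(y), float(z)])
    return atomic_numbers, positions


# Illustrative water geometry (coordinates in angstroms).
water_xyz = """3
water
O 0.000 0.000 0.117
H 0.000 0.757 -0.469
H 0.000 -0.757 -0.469
"""

atomic_numbers, positions = parse_xyz(water_xyz)
print(atomic_numbers)
print(positions)
```

From there, the two lists can be wrapped as tensors (integer atomic numbers, floating-point coordinates) and placed in a `Data` object as in the pseudocode above.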

Feel free to let me know if this is still unclear to you.

@fengjiang02
Author

Thank you :) I have also noticed some of your models on Hugging Face. Is it possible to obtain molecular structure information directly through Hugging Face? For example, your “MoleculeSTM_graph_property_prediction” model makes it easy to get embeddings, but it is not yet clear whether it contains structural information...

@chao1224
Owner

chao1224 commented Jan 8, 2025

Hi @fengjiang02,

I see your question. It is definitely doable; I just haven't implemented it yet, so users may need to hack the code a little. That shouldn't be very hard for anyone familiar with pytorch-geometric.

Were you able to solve the problem using the code above?
