forked from babaling/DRPreter
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdrug_graph.py
114 lines (93 loc) · 3.93 KB
/
drug_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from rdkit import Chem
import numpy as np
import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Data
from dgllife.utils import *
def atom_to_feature_vector(atom):
    """
    Featurize a single RDKit atom with dgllife's atom featurizers.

    The returned vector is the concatenation, in this order, of:
    element type (one-hot), degree (one-hot), implicit valence (one-hot),
    formal charge, number of radical electrons, hybridization (one-hot),
    aromaticity flag, total number of hydrogens (one-hot), in-ring flag and
    chirality tag (one-hot).

    Column widths come from dgllife's default allowable sets for each
    featurizer. With current defaults this presumably totals 77 values;
    confirm against the installed dgllife version before relying on it.

    :param atom: ``rdkit.Chem.rdchem.Atom`` to featurize
    :return: flat list of feature values (booleans/ints)
    """
    featurizer = ConcatFeaturizer([
        atom_type_one_hot,
        atom_degree_one_hot,
        atom_implicit_valence_one_hot,
        atom_formal_charge,
        atom_num_radical_electrons,
        atom_hybridization_one_hot,
        atom_is_aromatic,
        atom_total_num_H_one_hot,
        atom_is_in_ring,
        atom_chirality_type_one_hot,
    ])
    return featurizer(atom)
def bond_to_feature_vector(bond):
    """
    Featurize a single RDKit bond with dgllife's bond featurizers.

    Only the bond-type one-hot is emitted. The conjugation, ring-membership
    and stereo featurizers are listed below but commented out. Enabling any
    of them changes the length of the returned list for every bond.

    :param bond: ``rdkit.Chem.rdchem.Bond`` to featurize
    :return: flat list of feature values (booleans)
    """
    featurizer = ConcatFeaturizer([
        bond_type_one_hot,
        # bond_is_conjugated,
        # bond_is_in_ring,
        # bond_stereo_one_hot,
    ])
    return featurizer(bond)
def smiles2graph(mol):
    """
    Convert a SMILES string or RDKit molecule into a PyG ``Data`` graph.

    Salts and other disconnected fragments are NOT stripped: every atom of
    the input becomes a node.

    :param mol: SMILES string (``str``) or ``rdkit.Chem.rdchem.Mol``
    :return: ``torch_geometric.data.Data`` with
        ``x`` -- float tensor, one row per atom (``atom_to_feature_vector``);
        ``edge_index`` -- long tensor of shape [2, 2 * num_bonds], each bond
        stored once in each direction;
        ``edge_attr`` -- float tensor of shape [2 * num_bonds, num_bond_features]
        (``bond_to_feature_vector``), row k describing column k of edge_index.
    :raises ValueError: if ``mol`` is a string that RDKit cannot parse.
    """
    if not isinstance(mol, Chem.rdchem.Mol):
        smiles = mol
        mol = Chem.MolFromSmiles(smiles)
        # RDKit reports a parse failure by returning None rather than raising;
        # fail here with a readable message instead of an AttributeError below.
        if mol is None:
            raise ValueError(f'RDKit could not parse SMILES: {smiles!r}')

    # Node feature matrix: [num_atoms, num_atom_features].
    x = np.array([atom_to_feature_vector(atom) for atom in mol.GetAtoms()],
                 dtype=np.int64)

    # Each undirected bond becomes two directed edges sharing one feature row.
    edges_list = []
    edge_features_list = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_feature = bond_to_feature_vector(bond)
        edges_list.extend([(i, j), (j, i)])
        edge_features_list.extend([edge_feature, edge_feature])

    if edges_list:
        # COO connectivity, shape [2, num_edges].
        edge_index = np.array(edges_list, dtype=np.int64).T
        # Edge feature matrix, shape [num_edges, num_bond_features].
        edge_attr = np.array(edge_features_list, dtype=np.int64)
    else:
        # Bond-less molecule (e.g. a lone ion). edge_attr must have the same
        # width as for bonded molecules, otherwise concatenating graphs into a
        # batch fails on mismatched sizes. Measure the width on a reference
        # C-C bond so it always matches bond_to_feature_vector.
        reference_mol = Chem.MolFromSmiles('CC')
        num_bond_features = len(
            bond_to_feature_vector(reference_mol.GetBondWithIdx(0)))
        edge_index = np.empty((2, 0), dtype=np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype=np.int64)

    # Both feature matrices are returned as float tensors; dtype belongs to
    # torch.tensor, not to Data (where it would become a stray graph attribute).
    return Data(x=torch.tensor(x, dtype=torch.float),
                edge_index=torch.tensor(edge_index, dtype=torch.long),
                edge_attr=torch.tensor(edge_attr, dtype=torch.float))
def save_drug_graph(csv_path='Data/Drug/drug_smiles.csv',
                    out_path='Data/Drug/drug_feature_graph.npy'):
    """
    Build a graph for every drug in a CSV and save the collection to disk.

    The CSV is read positionally: column 0 supplies the dictionary key
    (presumably the drug name -- confirm against the data file) and column 2
    the SMILES string passed to ``smiles2graph``. If a key occurs more than
    once, the last row wins.

    ``np.save`` stores the dict as a 0-d object array, which is serialized
    with pickle. Reload it with
    ``np.load(out_path, allow_pickle=True).item()``, and only load files from
    a trusted source.

    :param csv_path: input CSV path (default is the project's original location)
    :param out_path: output ``.npy`` path (default is the project's original location)
    :return: dict mapping column-0 values to ``torch_geometric.data.Data`` graphs
    """
    smiles = pd.read_csv(csv_path)
    drug_dict = {
        key: smiles2graph(smi)
        for key, smi in zip(smiles.iloc[:, 0], smiles.iloc[:, 2])
    }
    np.save(out_path, drug_dict)
    return drug_dict
if __name__ == '__main__':
    # Script entry point: featurize every drug in the SMILES CSV and write the
    # resulting graph dictionary to disk.
    save_drug_graph()