Skip to content

Commit

Permalink
Commenting
Browse files Browse the repository at this point in the history
  • Loading branch information
FreyrS committed Aug 28, 2019
1 parent 52c7994 commit 05d15e6
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
1 change: 1 addition & 0 deletions source/data_preparation/04b-make_ligand_tfrecords.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
for i, pdb in enumerate(train_pdbs):
print("Working on", pdb)
try:
# Load precomputed data
input_feat = np.load(
os.path.join(precom_dir, pdb + "_", "p1_input_feat.npy")
)
Expand Down
7 changes: 5 additions & 2 deletions source/masif_ligand/masif_ligand_evaluate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

params = masif_opts["ligand"]
# Load testing data
testing_data = tf.contrib.data.TFRecordDataset(
os.path.join(params["tfrecords_dir"], "testing_data_sequenceSplit_30.tfrecord")
)
Expand All @@ -30,7 +31,7 @@


with tf.Session() as sess:
# Build trained network
# Build network
learning_obj = MaSIF_ligand(
sess,
params["max_distance"],
Expand All @@ -39,6 +40,7 @@
feat_mask=params["feat_mask"],
costfun=params["costfun"],
)
# Load pretrained network
learning_obj.saver.restore(learning_obj.session, output_model)

num_test_samples = 290
Expand All @@ -62,6 +64,7 @@
pdb_logits_softmax = []
pdb_labels = []
for ligand in range(n_ligands):
# Rows indicate point number and columns ligand type
pocket_points = np.where(labels[:, ligand] != 0.0)[0]
label = np.max(labels[:, ligand]) - 1
pocket_labels = np.zeros(7, dtype=np.float32)
Expand All @@ -77,7 +80,7 @@
samples_data_loss = []
# Make 100 predictions
for i in range(100):
# sample = samples[int(i*32):int((i+1)*32)]
# Sample pocket randomly
sample = np.random.choice(pocket_points, 32, replace=False)
feed_dict = {
learning_obj.input_feat: data_element[0][sample, :, :],
Expand Down
9 changes: 6 additions & 3 deletions source/masif_ligand/masif_ligand_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@
npoints = pocket_points.shape[0]
if npoints < 32:
continue
sample = np.random.choice(pocket_points, 32, replace=False)
# For evaluating take the first 32 points of the pocket
feed_dict = {
learning_obj.input_feat: data_element[0][pocket_points[:32], :, :],
Expand Down Expand Up @@ -112,7 +111,9 @@
num_epoch, np.mean(training_losses), np.median(training_losses)
)
)
# Generate confusion matrix
training_conf_mat = confusion_matrix(training_ytrue, training_ypred)
# Compute accuracy
training_accuracy = float(np.sum(np.diag(training_conf_mat))) / np.sum(
training_conf_mat
)
Expand All @@ -123,6 +124,7 @@
validation_ytrue = []
validation_ypred = []
print("Calulating validation loss")
# Compute accuracy on the validation set
for num_val_sample in range(num_validation_samples):
try:
data_element = sess.run(validation_next_element)
Expand All @@ -138,7 +140,6 @@
npoints = pocket_points.shape[0]
if npoints < 32:
continue
sample = np.random.choice(pocket_points, 32, replace=False)
feed_dict = {
learning_obj.input_feat: data_element[0][pocket_points[:32], :, :],
learning_obj.rho_coords: np.expand_dims(data_element[1], -1)[
Expand Down Expand Up @@ -195,7 +196,6 @@
npoints = pocket_points.shape[0]
if npoints < 32:
continue
sample = np.random.choice(pocket_points, 32, replace=False)
feed_dict = {
learning_obj.input_feat: data_element[0][pocket_points[:32], :, :],
learning_obj.rho_coords: np.expand_dims(data_element[1], -1)[
Expand Down Expand Up @@ -228,9 +228,11 @@
)
print(testing_conf_mat)
print("Testing accuracy:", testing_accuracy)
# Stop training if number of iterations has reached 40000
if total_iterations == 40000:
break

# Train the network
training_losses = []
training_ytrue = []
training_ypred = []
Expand All @@ -251,6 +253,7 @@
npoints = pocket_points.shape[0]
if npoints < 32:
continue
# Sample 32 points randomly
sample = np.random.choice(pocket_points, 32, replace=False)
feed_dict = {
learning_obj.input_feat: data_element[0][sample, :, :],
Expand Down

0 comments on commit 05d15e6

Please sign in to comment.