From 05d15e6bdcee6dc8eff0047735c2db59edcf70fc Mon Sep 17 00:00:00 2001 From: FreyrS Date: Wed, 28 Aug 2019 11:31:19 +0200 Subject: [PATCH] Commenting --- source/data_preparation/04b-make_ligand_tfrecords.py | 1 + source/masif_ligand/masif_ligand_evaluate_test.py | 7 +++++-- source/masif_ligand/masif_ligand_train.py | 9 ++++++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/source/data_preparation/04b-make_ligand_tfrecords.py b/source/data_preparation/04b-make_ligand_tfrecords.py index 8a260382..1a1bf514 100644 --- a/source/data_preparation/04b-make_ligand_tfrecords.py +++ b/source/data_preparation/04b-make_ligand_tfrecords.py @@ -52,6 +52,7 @@ for i, pdb in enumerate(train_pdbs): print("Working on", pdb) try: + # Load precomputed data input_feat = np.load( os.path.join(precom_dir, pdb + "_", "p1_input_feat.npy") ) diff --git a/source/masif_ligand/masif_ligand_evaluate_test.py b/source/masif_ligand/masif_ligand_evaluate_test.py index 36b57de5..38132e00 100644 --- a/source/masif_ligand/masif_ligand_evaluate_test.py +++ b/source/masif_ligand/masif_ligand_evaluate_test.py @@ -16,6 +16,7 @@ """ params = masif_opts["ligand"] +# Load testing data testing_data = tf.contrib.data.TFRecordDataset( os.path.join(params["tfrecords_dir"], "testing_data_sequenceSplit_30.tfrecord") ) @@ -30,7 +31,7 @@ with tf.Session() as sess: - # Build trained network + # Build network learning_obj = MaSIF_ligand( sess, params["max_distance"], @@ -39,6 +40,7 @@ feat_mask=params["feat_mask"], costfun=params["costfun"], ) + # Load pretrained network learning_obj.saver.restore(learning_obj.session, output_model) num_test_samples = 290 @@ -62,6 +64,7 @@ pdb_logits_softmax = [] pdb_labels = [] for ligand in range(n_ligands): + # Rows indicate point number and columns ligand type pocket_points = np.where(labels[:, ligand] != 0.0)[0] label = np.max(labels[:, ligand]) - 1 pocket_labels = np.zeros(7, dtype=np.float32) @@ -77,7 +80,7 @@ samples_data_loss = [] # Make 100 predictions for i in range(100): - # sample = samples[int(i*32):int((i+1)*32)] + # Sample pocket randomly sample = np.random.choice(pocket_points, 32, replace=False) feed_dict = { learning_obj.input_feat: data_element[0][sample, :, :], diff --git a/source/masif_ligand/masif_ligand_train.py b/source/masif_ligand/masif_ligand_train.py index 5ae8b37d..aaaba3a6 100644 --- a/source/masif_ligand/masif_ligand_train.py +++ b/source/masif_ligand/masif_ligand_train.py @@ -84,7 +84,6 @@ npoints = pocket_points.shape[0] if npoints < 32: continue - sample = np.random.choice(pocket_points, 32, replace=False) # For evaluating take the first 32 points of the pocket feed_dict = { learning_obj.input_feat: data_element[0][pocket_points[:32], :, :], @@ -112,7 +111,9 @@ num_epoch, np.mean(training_losses), np.median(training_losses) ) ) + # Generate confusion matrix training_conf_mat = confusion_matrix(training_ytrue, training_ypred) + # Compute accuracy training_accuracy = float(np.sum(np.diag(training_conf_mat))) / np.sum( training_conf_mat ) @@ -123,6 +124,7 @@ validation_ytrue = [] validation_ypred = [] print("Calulating validation loss") + # Compute accuracy on the validation set for num_val_sample in range(num_validation_samples): try: data_element = sess.run(validation_next_element) @@ -138,7 +140,6 @@ npoints = pocket_points.shape[0] if npoints < 32: continue - sample = np.random.choice(pocket_points, 32, replace=False) feed_dict = { learning_obj.input_feat: data_element[0][pocket_points[:32], :, :], learning_obj.rho_coords: np.expand_dims(data_element[1], -1)[ @@ -195,7 +196,6 @@ npoints = pocket_points.shape[0] if npoints < 32: continue - sample = np.random.choice(pocket_points, 32, replace=False) feed_dict = { learning_obj.input_feat: data_element[0][pocket_points[:32], :, :], learning_obj.rho_coords: np.expand_dims(data_element[1], -1)[ @@ -228,9 +228,11 @@ ) print(testing_conf_mat) print("Testing accuracy:", testing_accuracy) + # Stop training if number of iterations has reached 40000 if total_iterations == 40000: break + # Train the network training_losses = [] training_ytrue = [] training_ypred = [] @@ -251,6 +253,7 @@ npoints = pocket_points.shape[0] if npoints < 32: continue + # Sample 32 points randomly sample = np.random.choice(pocket_points, 32, replace=False) feed_dict = { learning_obj.input_feat: data_element[0][sample, :, :],