diff --git a/play_ground.ipynb b/play_ground.ipynb new file mode 100644 index 0000000..dbd0ffa --- /dev/null +++ b/play_ground.ipynb @@ -0,0 +1,637 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from math import ceil\n", + "import json\n", + "import os\n", + "import sys\n", + "import random\n", + "from itertools import combinations\n", + "\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.metrics import roc_auc_score, roc_curve, classification_report\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from datasets import link_prediction\n", + "from layers import MeanAggregator, LSTMAggregator, MaxPoolAggregator, MeanPoolAggregator\n", + "from models import DGNN, AAGNN\n", + "from models_variants import EAAGNN, EAACGNN\n", + "import models\n", + "import utils" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------------------\n", + "Reading edge dataset from HE_T1087_84_Default_Extended_1_2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/hasegawatai/college/research/DigitalPathology/Digital-Pathology/datasets/link_prediction.py:410: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " distances_close_to_edges[target][source] = distance\n", + "/Users/hasegawatai/college/research/DigitalPathology/Digital-Pathology/datasets/link_prediction.py:409: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " distances_close_to_edges[source][target] = distance\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<<<<<<<<<<<<< id lym epi fib inf x y gt \\\n", + "0 0 0.0 0.0 0.0 0.0 127 1682 apoptosis / civiatte body \n", + "1 1 0.0 0.0 0.0 0.0 147 1690 apoptosis / civiatte body \n", + "2 2 0.0 0.0 0.0 0.0 315 1656 apoptosis / civiatte body \n", + "3 3 0.0 1.0 0.0 0.0 1027 1763 3 \n", + "4 4 0.0 1.0 0.0 0.0 1122 1879 3 \n", + "... ... ... ... ... ... ... ... ... \n", + "1008 1008 0.0 1.0 0.0 0.0 487 1911 3 \n", + "1009 1009 0.0 1.0 0.0 0.0 536 1681 3 \n", + "1010 1010 0.0 1.0 0.0 0.0 528 1710 3 \n", + "1011 1011 0.0 1.0 0.0 0.0 454 1675 3 \n", + "1012 1012 1.0 0.0 0.0 0.0 427 1637 1 \n", + "\n", + " lym_density epi_density fib_density inf_density \n", + "0 23.675918 20.109769 0.000000 0.0 \n", + "1 17.464249 24.360860 0.000000 0.0 \n", + "2 47.169906 51.491275 0.000000 0.0 \n", + "3 32.930879 20.498191 0.000000 0.0 \n", + "4 40.468889 61.252811 0.000000 0.0 \n", + "... ... ... ... ... \n", + "1008 0.000000 41.093058 0.000000 0.0 \n", + "1009 50.803543 25.424011 62.936476 0.0 \n", + "1010 0.000000 29.885451 0.000000 0.0 \n", + "1011 39.292955 45.286881 0.000000 0.0 \n", + "1012 22.781464 35.109144 0.000000 0.0 \n", + "\n", + "[1013 rows x 12 columns]\n", + "(1013, 4)\n", + "Finished reading data.\n", + "Setting up graph.\n", + "self.features.shape: torch.Size([1013, 4])\n", + "Finished setting up graph.\n", + "Setting up examples.\n", + "self.mode != 'train'\n", + "Finished setting up examples.\n", + "Dataset properties:\n", + "Mode: val\n", + "Number of vertices: 1013\n", + "Number of edges: 2995\n", + "Number of positive/negative datapoints: 190/2805\n", + "Number of examples/datapoints: 2995\n", + "--------------------------------\n" + ] + } + ], + "source": [ + "# Set up arguments for datasets, models and training.\n", + "is_train = True\n", + "is_test = False\n", + "is_val = True\n", + "is_debug = True\n", + "conf_device = None\n", + "hidden_dim = [8]\n", + "batch_size = 32\n", + "dataset_folder = \"gen2_synthetic_csv_test_files/70_30_split\"\n", + "\n", + "num_layers = len(hidden_dim) + 1\n", + "\n", + "\n", + "if False and torch.cuda.is_available():\n", + " device = 'cuda:0'\n", + "else:\n", + " device = 'cpu'\n", + "conf_device = device\n", + "\n", + "# Get the dataset, dataloader and model.\n", + "if not is_val and not is_test:\n", + " dataset_args = ('train', num_layers)\n", + "\n", + "if is_val:\n", + " dataset_args = ('val', num_layers)\n", + "\n", + "if is_test:\n", + " dataset_args = ('test', num_layers)\n", + "\n", + "datasets = utils.get_dataset_gcn(dataset_args, dataset_folder, is_debug=is_debug)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# load data\n", + "loaders = []\n", + "for i in range(len(datasets)):\n", + " loaders.append(DataLoader(dataset=datasets[i], batch_size=batch_size,\n", + " shuffle=True, collate_fn=datasets[i].collate_wrapper))\n", + " \n", + "loader = DataLoader(dataset=datasets[0], batch_size=batch_size,\n", + " shuffle=False, collate_fn=datasets[0].collate_wrapper)\n", + "for (idx, batch) in enumerate(loader):\n", + " adj, features, edge_features, adj_relative_cos, edges, labels, dist = batch\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# load model\n", + "directory = \"models/\"\n", + "fname = \"aagnn_c2a1_rev_drop03_epoch50_hid8_out1_70_30_split_exp_saved_model.pth\"\n", + "path = os.path.join(directory, fname)\n", + "model = AAGNN(4, 8, 1, 0.3, 'cpu')\n", + "sigmoid = nn.Sigmoid()\n", + "model.load_state_dict(torch.load(path))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " ROC-AUC score: 0.8710\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 0.7185\n", + " ROC-AUC score: 0.8046\n", + " ROC-AUC score: 0.9167\n", + " ROC-AUC score: 0.8500\n", + " ROC-AUC score: 0.9821\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 0.6500\n", + " ROC-AUC score: 0.9833\n", + " ROC-AUC score: 0.6500\n", + " ROC-AUC score: 0.9852\n", + " ROC-AUC score: 0.7926\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 0.8506\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 0.9333\n", + " ROC-AUC score: 0.8387\n", + " ROC-AUC score: 0.8000\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 0.8710\n", + " ROC-AUC score: 0.8736\n", + " ROC-AUC score: 0.9032\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 0.8000\n", + " ROC-AUC score: 0.8833\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 0.9000\n", + " ROC-AUC score: 0.6407\n", + " ROC-AUC score: 0.6897\n", + " ROC-AUC score: 0.8125\n", + " ROC-AUC score: 0.9885\n", + " ROC-AUC score: 0.6452\n", + " ROC-AUC score: 0.8667\n", + " ROC-AUC score: 0.6322\n", + " ROC-AUC score: 0.7667\n", + " ROC-AUC score: 0.8391\n", + " ROC-AUC score: 0.7143\n", + " ROC-AUC score: 0.9425\n", + " ROC-AUC score: 0.6333\n", + " ROC-AUC score: 0.9333\n", + " ROC-AUC score: 0.8000\n", + " ROC-AUC score: 0.2833\n", + " ROC-AUC score: 0.1935\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 0.3678\n", + " ROC-AUC score: 0.6778\n", + " ROC-AUC score: 0.7816\n", + " ROC-AUC score: 0.7436\n", + " ROC-AUC score: 0.8161\n", + " ROC-AUC score: 0.6667\n", + " ROC-AUC score: 0.6333\n", + " ROC-AUC score: 0.9911\n", + " ROC-AUC score: 0.9833\n", + " ROC-AUC score: 0.2903\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 0.7630\n", + " ROC-AUC score: 0.6827\n", + " ROC-AUC score: 0.8506\n", + " ROC-AUC score: 0.8333\n", + " ROC-AUC score: 0.7321\n", + " ROC-AUC score: 0.4023\n", + " ROC-AUC score: 0.9770\n", + " ROC-AUC score: 0.7259\n", + " ROC-AUC score: 0.8333\n", + " ROC-AUC score: 0.9080\n", + " ROC-AUC score: 0.4516\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 0.7857\n", + " ROC-AUC score: 0.4000\n", + " ROC-AUC score: 1.0000\n", + " ROC-AUC score: 0.8046\n", + " ROC-AUC score: 0.8667\n", + " ROC-AUC score: 0.8387\n", + " ROC-AUC score: 0.5167\n", + " ROC-AUC score: 0.8393\n", + " ROC-AUC score: 0.9677\n", + " ROC-AUC score: 0.9481\n", + "Threshold: 0.1000, accuracy: 0.7586\n", + "ROC-AUC score: 0.8162\n", + "Classification report\n", + " precision recall f1-score support\n", + "\n", + " 0.0 0.9786 0.7615 0.8565 17212\n", + " 1.0 0.1443 0.7068 0.2396 979\n", + "\n", + " accuracy 0.7586 18191\n", + " macro avg 0.5614 0.7342 0.5481 18191\n", + "weighted avg 0.9337 0.7586 0.8233 18191\n", + "\n" + ] + } + ], + "source": [ + "# run model\n", + "t = 0.1\n", + "stats_per_batch = 5\n", + "results_dir = \"results/70_30_split\"\n", + "\n", + "running_loss, total_loss = 0.0, 0.0\n", + "num_correct, num_examples = 0, 0\n", + "total_correct, total_examples, total_batches = 0, 0, 0\n", + "y_true, y_scores, y_pred = [], [], []\n", + "for i in range(len(datasets)):\n", + " num_batches = int(ceil(len(datasets[i]) / batch_size))\n", + " total_batches += num_batches\n", + " # Added by Jorge\n", + " edge_pred, neg_pred, classes, coords = [], [], None, None\n", + " # --------------\n", + " with torch.no_grad():\n", + " for (idx, batch) in enumerate(loaders[i]):\n", + " adj, features, densities, adj_relative_cos, edges, labels, dist = batch\n", + " labels = labels.to(device)\n", + " \n", + " adj_relative_cos = {node: {(i, j): cos.to(device) for (i, j), cos in node_adj.items()} for node, node_adj in adj_relative_cos.items()}\n", + " adj, features = adj.to(device), features.to(device)\n", + " out = model(features, adj, adj_relative_cos)\n", + " \n", + " all_pairs = torch.mm(out, out.t())\n", + " all_pairs = sigmoid(all_pairs)\n", + " scores = 1 - all_pairs[edges.T]\n", + " predictions = (scores >= t).long()\n", + "\n", + " num_correct += torch.sum(predictions == labels.long()).item()\n", + " total_correct += torch.sum(predictions == labels.long()).item()\n", + " num_examples += len(labels)\n", + " total_examples += len(labels)\n", + " y_true.extend(labels.detach().cpu().numpy())\n", + " y_scores.extend(scores.detach().cpu().numpy())\n", + " y_pred.extend(predictions.detach().cpu().numpy())\n", + " if (idx + 1) % stats_per_batch == 0:\n", + " accuracy = num_correct / num_examples\n", + " if (torch.sum(labels.long() == 0).item() > 0) and (torch.sum(labels.long() == 1).item() > 0):\n", + " area = roc_auc_score(labels.detach().cpu().numpy(), scores.detach().cpu().numpy())\n", + " print(' ROC-AUC score: {:.4f}'.format(area))\n", + " num_correct, num_examples = 0, 0\n", + " num_correct, num_examples = 0, 0\n", + "\n", + "total_accuracy = total_correct / total_examples\n", + "print('Threshold: {:.4f}, accuracy: {:.4f}'.format(t, total_correct / total_examples))\n", + "y_true = np.array(y_true).flatten()\n", + "y_scores = np.array(y_scores).flatten()\n", + "y_pred = np.array(y_pred).flatten()\n", + "report = classification_report(y_true, y_pred, digits=4)\n", + "area = roc_auc_score(y_true, y_scores)\n", + "print('ROC-AUC score: {:.4f}'.format(area))\n", + "print('Classification report\\n', report)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "features = torch.Tensor([\n", + " [1, 2, 3, 4],\n", + " [1, 1, 1, 1],\n", + " [1, 2, 2, 1]\n", + "])\n", + "\n", + "weight0 = torch.Tensor([\n", + " [1, 1, 1, 1, 1],\n", + " [1, 2, 3, 4, 5],\n", + " [2, 2, 2, 2, 2],\n", + " [5, 4, 3, 2, 1]\n", + "])\n", + "\n", + "weight1 = torch.Tensor([\n", + " [1, 1, 1, 1, 1],\n", + " [0, 0, 0, 0, 0],\n", + " [2, 2, 2, 2, 2],\n", + " [0, 0, 0, 0, 0]\n", + "])\n", + "\n", + "edge_features = torch.Tensor([\n", + " [\n", + " [1, 2, 3],\n", + " [3, 2, 1],\n", + " [2, 2, 2]\n", + " ],\n", + " [\n", + " [1, 1, 1],\n", + " [2, 2, 2],\n", + " [3, 3, 3]\n", + " ]\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "x0 = torch.matmul(features, weight0)\n", + "x1 = torch.matmul(features, weight1)\n", + "x2 = torch.matmul(edge_features, x0)\n", + "output = torch.cat([xi for xi in x2], dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 83., 84., 85., 86., 87., 50., 49., 48., 47., 46.],\n", + " [117., 112., 107., 102., 97., 100., 98., 96., 94., 92.],\n", + " [100., 98., 96., 94., 92., 150., 147., 144., 141., 138.]])" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[29., 27., 25., 23., 21.],\n", + " [ 9., 9., 9., 9., 9.],\n", + " [12., 13., 14., 15., 16.]])" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x0" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[7., 7., 7., 7., 7.],\n", + " [3., 3., 3., 3., 3.],\n", + " [5., 5., 5., 5., 5.]])" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x1" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 83., 84., 85., 86., 87.],\n", + " [117., 112., 107., 102., 97.],\n", + " [100., 98., 96., 94., 92.]],\n", + "\n", + " [[ 50., 49., 48., 47., 46.],\n", + " [100., 98., 96., 94., 92.],\n", + " [150., 147., 144., 141., 138.]]])" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x2" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 90., 91., 92., 93., 94.],\n", + " [120., 115., 110., 105., 100.],\n", + " [105., 103., 101., 99., 97.]],\n", + "\n", + " [[ 57., 56., 55., 54., 53.],\n", + " [103., 101., 99., 97., 95.],\n", + " [155., 152., 149., 146., 143.]]])" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x2 + x1" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "usage: ipykernel_launcher.py [-h] [--json JSON]\n", + " [--stats_per_batch STATS_PER_BATCH]\n", + " [--results_dir RESULTS_DIR]\n", + " [--saved_models_dir SAVED_MODELS_DIR]\n", + " [--task {unsupervised,link_prediction}]\n", + " [--agg_class {,,,}]\n", + " [--cuda] [--dropout DROPOUT]\n", + " [--hidden_dims [HIDDEN_DIMS ...]]\n", + " [--out_dim OUT_DIM] [--num_samples NUM_SAMPLES]\n", + " [--classifier {pos_sig,neg_sig,mlp}]\n", + " [--batch_size BATCH_SIZE] [--epochs EPOCHS]\n", + " [--lr LR] [--weight_decay WEIGHT_DECAY]\n", + " [--debug DEBUG] [--save] [--test] [--val]\n", + "ipykernel_launcher.py: error: unrecognized arguments: --ip=127.0.0.1 --stdin=9008 --control=9006 --hb=9005 --Session.signature_scheme=\"hmac-sha256\" --Session.key=b\"61db5cec-0146-49f0-ad3d-88a07cf77209\" --shell=9007 --transport=\"tcp\" --iopub=9009 --f=/var/folders/l9/5gsd4gnx6cs5qwd88wl8cwqh0000gn/T/tmp-563IB4fQk9BHxdP.json\n" + ] + }, + { + "ename": "SystemExit", + "evalue": "2", + "output_type": "error", + "traceback": [ + "An exception has occurred, use %tb to see the full traceback.\n", + "\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/hasegawatai/anaconda3/envs/DigitalPathology/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3259: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.\n", + " warn(\"To exit: use 'exit', 'quit', or Ctrl-D.\", stacklevel=1)\n" + ] + } + ], + "source": [ + "p = utils.parse_args()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "a = torch.Tensor([\n", + " [1, 2, 3],\n", + " [2, 3, 4],\n", + " [3, 4, 5]\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[1., 2., 3.],\n", + " [2., 3., 4.],\n", + " [3., 4., 5.]]])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a.reshape(-1, 3, 3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "interpreter": { + "hash": "ccaac013b6b6e9b32d0a0d6e96288a02ed656e9518d9a1239996f4b4a38ebffb" + }, + "kernelspec": { + "display_name": "Python 3.9.7 ('DigitalPathology')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/play_ground.py b/play_ground.py new file mode 100644 index 0000000..d98745d --- /dev/null +++ b/play_ground.py @@ -0,0 +1,144 @@ +from math import ceil +import json +import os +import sys +import random + +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import roc_auc_score, roc_curve, classification_report +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader + +from datasets import link_prediction +from layers import MeanAggregator, LSTMAggregator, MaxPoolAggregator, MeanPoolAggregator +from models import DGNN, AAGNN +from models_variants import EAAGNN, EAACGNN +import utils + +# Set up arguments for datasets, models and training. +config = utils.parse_args() +print(config) +print(config['classifier']) +# config['num_layers'] = len(config['hidden_dims']) + 1 + +# if config['cuda'] and torch.cuda.is_available(): +# device = 'cuda:0' +# else: +# device = 'cpu' +# config['device'] = device + + +# print(f"val: {config['val']}") +# print(f"test: {config['test']}") + + +# # Get the dataset, dataloader and model. +# if not config['val'] and not config['test']: +# dataset_args = ('train', config['num_layers']) + +# if config['val']: +# dataset_args = ('val', config['num_layers']) + +# if config['test']: +# dataset_args = ('test', config['num_layers']) + +# datasets = utils.get_dataset_gcn(dataset_args, config['dataset_folder'], is_debug=True) + + +# loader = DataLoader(dataset=datasets[0], batch_size=config['batch_size'], +# shuffle=False, collate_fn=datasets[0].collate_wrapper) + +# loaders = [] +# for i in range(len(datasets)): +# loaders.append(DataLoader(dataset=datasets[i], batch_size=config['batch_size'], +# shuffle=True, collate_fn=datasets[i].collate_wrapper)) + +# input_dim, output_dim = datasets[0].get_dims() +# model = AAGNN(input_dim, config['hidden_dims'][0], output_dim, +# config['dropout'], config['device']) +# model.to(config['device']) + + +# sigmoid = nn.Sigmoid() +# criterion = utils.get_criterion(config['task']) + +# optimizer = optim.Adam(model.parameters(), lr=config['lr'], +# weight_decay=config['weight_decay']) +# epochs = config['epochs'] +# epochs = 10 + +# #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1500, gamma=0.8) +# scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10, 15, 20, 25], gamma=0.5) # Epoch decay +# model.train() +# print('--------------------------------') +# print('Training.') +# for epoch in range(epochs): +# print('Epoch {} / {}'.format(epoch+1, epochs)) + +# epoch_loss = 0.0 +# epoch_roc = 0.0 +# epoch_batches = 0 +# shuffle = list(range(len(loaders))) +# random.shuffle(shuffle) # Shuffle order of graphs +# for i in shuffle: +# num_batches = int(ceil(len(datasets[i]) / config['batch_size'])) +# epoch_batches += num_batches +# graph_roc = 0.0 +# running_loss = 0.0 +# for (idx, batch) in enumerate(loaders[i]): +# adj, adj_list, features, coords, edges, labels, dist, node_layers = batch +# labels = labels.to(device) +# optimizer.zero_grad() + +# adj_relative_cos = utils.get_relative_cos_list(adj_list, coords, device) +# adj, features = adj.to(device), features.to(device) +# out = model(features, adj, adj_relative_cos) + +# all_pairs = torch.mm(out, out.t()) +# all_pairs = sigmoid(all_pairs) +# scores = all_pairs[edges.T] +# loss = criterion(scores, labels.float()) +# loss.backward() +# optimizer.step() +# with torch.no_grad(): +# running_loss += loss.item() +# epoch_loss += loss.item() +# if (torch.sum(labels.long() == 0).item() > 0) and (torch.sum(labels.long() == 1).item() > 0): +# area = roc_auc_score(labels.detach().cpu().numpy(), scores.detach().cpu().numpy()) +# epoch_roc += area +# graph_roc += area +# running_loss /= num_batches +# print(' Graph {} / {}: loss {:.4f}'.format( +# i+1, len(datasets), running_loss)) +# print(' ROC-AUC score: {:.4f}'.format(graph_roc/num_batches)) + +# scheduler.step() +# print("Epoch avg loss: {}".format(epoch_loss / epoch_batches)) +# print("Epoch avg ROC_AUC score: {}".format(epoch_roc / epoch_batches)) + +# print('Finished training.') +# print('--------------------------------') + +# y_true, y_scores = [], [] +# for batch in loader: +# adj, adj_list, features, coords, edges, labels, dist, node_layers = batch +# labels = labels.to(device) +# adj_relative_cos = utils.get_relative_cos_list(adj_list, coords, device) +# adj, features = adj.to(device), features.to(device) +# out = model(features, adj, adj_relative_cos) +# all_pairs = torch.mm(out, out.t()) +# all_pairs = sigmoid(all_pairs) +# scores = all_pairs[edges.T] +# y_true.extend(labels.detach().cpu().numpy()) +# y_scores.extend(scores.detach().cpu().numpy()) + +# y_true = np.array(y_true).flatten() +# y_scores = np.array(y_scores).flatten() +# area = roc_auc_score(y_true, y_scores) + +# print(y_scores) +# print(y_scores[np.isnan(y_scores)]) +# print(area) \ No newline at end of file