-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathdataset.py
81 lines (64 loc) · 2.37 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
'''
Copyright (c) 2020, Abdelrahman Hosny <[email protected]>
All rights reserved.
BSD 3-Clause License
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
import math
import networkx as nx
import numpy as np
from sklearn.preprocessing import LabelEncoder
from dgl import DGLGraph
__all__ = ['ICTunerDataset']
class ICTunerDataset(object):
    """In-memory dataset of design graphs paired with integer class labels.

    Samples are added one at a time with :meth:`add_design`.  Each sample is
    a graph plus the name of the design it belongs to; integer labels are
    derived from the design names with scikit-learn's ``LabelEncoder`` and
    are recomputed for the whole dataset every time a sample is added.
    """

    def __init__(self):
        super(ICTunerDataset, self).__init__()
        # Three parallel sequences, indexed by sample position.
        self.graphs = []
        self.design_names = []
        # Starts as an empty list; replaced by the encoder's output on the
        # first call to ``add_design``.
        self.labels = []

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Get the i^th sample.

        Parameters
        ----------
        idx : int
            The sample index.

        Returns
        -------
        (graph, label)
            The graph exactly as it was passed to ``add_design`` (presumably
            a ``dgl.DGLGraph``) and its encoded integer label.
        """
        return self.graphs[idx], self.labels[idx]

    @property
    def num_classes(self):
        """Number of distinct classes, i.e. of unique design names.

        Samples sharing a design name share a label, so duplicates must not
        be counted twice.
        """
        return len(set(self.design_names))

    @property
    def dataset_labels(self):
        """Encoded labels of all samples, in insertion order."""
        return self.labels

    @property
    def dataset_design_names(self):
        """Design names of all samples, in insertion order."""
        return self.design_names

    def _create_labels(self):
        """Re-encode every design name into an integer label.

        A fresh ``LabelEncoder`` is fitted on the full list of names, so the
        label assigned to an existing design may change when a new design
        name is added.
        """
        labelencoder = LabelEncoder()
        self.labels = labelencoder.fit_transform(self.design_names)

    def add_design(self, graph, design_name):
        """Append one sample and refresh the labels of the whole dataset.

        Parameters
        ----------
        graph
            The graph to store; it is kept as-is.
        design_name
            Name of the design the graph belongs to; used as the class key.
        """
        self.graphs.append(graph)
        self.design_names.append(design_name)
        self._create_labels()

    def labels_map(self):
        """Return a dict mapping each encoded label to its design name."""
        return dict(zip(self.labels, self.design_names))