-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcluster.py
78 lines (62 loc) · 2.2 KB
/
cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
class Cluster:
    """
    Base class holding a data set, its cluster centers and point labels.

    ``labels`` is not computed here: it must be assigned (one center index per
    row of ``data``) before the SSW-based metrics can be computed. The error
    raised by ``calc_SSW`` refers to a ``classify`` method that is not defined
    in this class (presumably provided by subclasses — verify).
    """

    def __init__(self, data: np.ndarray, centers: np.ndarray):
        """
        Store the data set and the initial cluster centers.

        :param data: ndarray, the data set to be clustered, one point per row
        :param centers: ndarray, initial cluster centers, one center per row
        """
        self.data = data
        self.centers = centers
        # Index into ``centers`` for each row of ``data``; None until assigned.
        self.labels: "np.ndarray | None" = None

    @staticmethod
    def dist(point_1: np.ndarray, point_2: np.ndarray,
             axis: "int | None" = None) -> "np.ndarray | float":
        """
        Calc the squared euclidean distance between points.

        The inputs are broadcast against each other, so an (n, m) array can be
        compared with a single (m,) point.

        :param point_1: ndarray with shape (n, m) or (m,) of first points
        :param point_2: ndarray with shape (n, m) or (m,) of second points
        :param axis: axis over which the squared differences are summed. The
            default ``None`` sums every element and returns one scalar (the
            total squared distance); ``axis=-1`` returns one squared distance
            per point, i.e. shape (n,) for (n, m) inputs.
        :return: scalar total, or ndarray of per-point squared distances
        """
        return np.sum(np.square(point_1 - point_2), axis=axis)

    def set_centers(self, centers: np.ndarray):
        """
        Replace the cluster centers.

        Existing ``labels`` are left unchanged.

        :param centers: ndarray, new centers (one per row); must have the same
            number of columns as the current centers
        :raises ValueError: if ``centers`` is not at least 2-D or its column
            count differs from the current centers
        """
        centers = np.asarray(centers)
        # A bare string cannot be raised in Python 3 (it becomes a TypeError),
        # so raise a real exception carrying the same message.
        if centers.ndim < 2 or centers.shape[1] != self.centers.shape[1]:
            raise ValueError("No valid centers")
        self.centers = centers

    def calc_SSW(self) -> float:
        """
        Calc the SSW coefficient (within-cluster sum of squares).

        Sum over all points of the squared distance between the point and the
        center it is labelled with.

        :return: float, SSW coefficient
        :raises RuntimeError: if no labels have been assigned yet
        :raises ValueError: if there is not exactly one label per data point
        """
        if self.labels is None:
            raise RuntimeError("No labels allocated, first call classify method")
        labels = np.ravel(self.labels)
        # Without this check a single label would silently broadcast against
        # every data point, and a short label array would give a partial sum.
        if len(labels) != len(self.data):
            raise ValueError("Labels must contain exactly one entry per data point")
        if labels.size == 0:
            return 0.0
        # centers[labels] gathers each point's assigned center, so one
        # vectorised call replaces the per-point Python loop.
        return float(self.dist(self.data, self.centers[labels]))

    def calc_TSS(self) -> float:
        """
        Calc the TSS coefficient (total sum of squares).

        Sum over all points of the squared distance to the mean of the data.

        :return: float, TSS coefficient
        """
        return float(self.dist(self.data, self.data.mean(axis=0)))

    def calc_SSB(self) -> float:
        """
        Calc the SSB coefficient as TSS - SSW.

        :return: float, SSB coefficient
        :raises RuntimeError: if no labels have been assigned yet
        """
        return self.calc_TSS() - self.calc_SSW()

    def get_metrics(self) -> dict:
        """
        Compute SSW, TSS and SSB in one pass (each sum is computed once).

        :return: dict with keys "SSW", "TSS" and "SSB" (SSB = TSS - SSW)
        :raises RuntimeError: if no labels have been assigned yet
        """
        ssw = self.calc_SSW()
        tss = self.calc_TSS()
        return {"SSW": ssw, "TSS": tss, "SSB": tss - ssw}