-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathscikit-dbscan-example.py
64 lines (47 loc) · 1.93 KB
/
scikit-dbscan-example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# -*- coding: utf-8 -*-
"""
This script is used to validate that my implementation of DBSCAN produces
the same results as the implementation found in scikit-learn.
It's based on the scikit-learn example code, here:
http://scikit-learn.org/stable/auto_examples/cluster/plot_dbscan.html
@author: Chris McCormick
"""
# NOTE: `make_blobs` is imported from the public `sklearn.datasets` package.
# The old `sklearn.datasets.samples_generator` module path was deprecated in
# scikit-learn 0.22 and removed in 0.24; the public path works on old and new
# releases alike.
from sklearn.datasets import make_blobs
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler

from dbscan import MyDBSCAN

# Create three gaussian blobs to use as our clustering data. `random_state`
# is fixed so both DBSCAN implementations see exactly the same points on
# every run.
centers = [[1, 1], [-1, -1], [1, -1]]
X, labels_true = make_blobs(n_samples=750, centers=centers, cluster_std=0.4,
                            random_state=0)

# Standardize the features before clustering; both implementations below are
# run on this scaled copy of the data.
X = StandardScaler().fit_transform(X)
###############################################################################
# My implementation of DBSCAN
#

# Run my DBSCAN implementation on the data set `X`. The parenthesized,
# single-argument form of print produces the same output on Python 2 and
# Python 3 (the bare `print '...'` statement is a SyntaxError on Python 3).
print('Running my implementation...')
my_labels = MyDBSCAN(X, eps=0.3, MinPts=10)
###############################################################################
# Scikit-learn implementation of DBSCAN
#

# The parenthesized, single-argument form of print produces the same output
# on Python 2 and Python 3.
print('Runing scikit-learn implementation...')
db = DBSCAN(eps=0.3, min_samples=10).fit(X)
skl_labels = db.labels_

# Scikit-learn uses -1 for NOISE and starts cluster labeling at 0. I start
# numbering at 1, so increment every non-noise scikit-learn label by 1.
# `labels_` is a NumPy array, so the boolean mask updates it in place
# (`skl_labels` and `db.labels_` remain the same array, as with the
# element-by-element loop this replaces).
skl_labels[skl_labels != -1] += 1
###############################################################################
# Did we get the same results?

num_disagree = 0

# Go through each label and make sure they match (print the labels if they
# don't). `my_labels` is indexed explicitly rather than zipped so that a
# too-short result still raises instead of being silently truncated.
for i, skl_label in enumerate(skl_labels):
    my_label = my_labels[i]
    if skl_label != my_label:
        # Pre-format the message into one string so the output is identical
        # on Python 2 and Python 3.
        print('Scikit learn: %s mine: %s' % (skl_label, my_label))
        num_disagree += 1

if num_disagree == 0:
    print('PASS - All labels match!')
else:
    print("FAIL - %d labels don't match." % num_disagree)