-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcluster.py
26 lines (22 loc) · 970 Bytes
/
cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
import pickle
import time
from sklearn.cluster import AgglomerativeClustering, KMeans
from util.arg import cluster_args
from util.encode import encode_dataset
def cluster(dataset, args):
    """Fit a clustering model and tag every example with its cluster label.

    The model is ``KMeans`` when ``args.method == 'kmeans'`` and
    ``AgglomerativeClustering`` for any other value; ``args.cluster`` is
    passed as the first positional argument of the chosen estimator.
    The feature for each example is read from the key
    ``args.encoding + '_encoding'``.

    Args:
        dataset: Iterable of dict-like examples, each holding the encoding
            key described above. It is iterated twice, so it must be
            re-iterable (presumably a list -- confirm against
            ``encode_dataset``).
        args: Options object providing ``method``, ``cluster`` and
            ``encoding`` attributes.

    Returns:
        None. Each example is mutated in place: its ``'cluster'`` entry is
        set to the matching element of the fitted model's ``labels_``.
    """
    if args.method == 'kmeans':
        model = KMeans(args.cluster)
    else:
        model = AgglomerativeClustering(args.cluster)

    encoding_key = args.encoding + '_encoding'
    features = [example[encoding_key] for example in dataset]
    model.fit(features)

    # labels_ is ordered like the fitted samples, so it pairs up with the
    # examples one-to-one.
    for example, label in zip(dataset, model.labels_):
        example['cluster'] = label
# Read the options that control encoding, clustering and the output name.
args = cluster_args()

# Stage 1: load and encode the training split, reporting size and duration.
start_time = time.time()
dataset = encode_dataset('train', args)
print(f'Dataset size: train -> {len(dataset):d} ;')
print(f'Load dataset finished, cost {time.time() - start_time:.4f}s ;')

# Stage 2: cluster the encoded examples; labels are written into `dataset`.
print('Start clustering ...')
start_time = time.time()
cluster(dataset, args)
print(f'Clustering costs {time.time() - start_time:.2f}s ;')

# Stage 3: pickle the labelled dataset under data/<dataset>/, with the
# method, cluster count and encoding recorded in the file name.
output_dir = os.path.join('data', args.dataset)
output_name = f'train.{args.method}.{args.cluster}.{args.encoding}.bin'
output_path = os.path.join(output_dir, output_name)
with open(output_path, 'wb') as file:
    pickle.dump(dataset, file)