import copy
from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt
import random
class DBSCAN:
    """Density-based clustering (DBSCAN) over a 2-D array of samples.

    Attributes:
        dataset: array of shape (N, M), one sample per row.
        eps: neighbourhood radius (Euclidean distance, inclusive).
        min_point: minimum neighbourhood size for a point to count as a core
            point. The neighbourhood includes the point itself, because its
            distance to itself is 0 <= eps.
        N: number of samples.
        M: number of features per sample.
    """

    def __init__(self, dataset, eps, min_point):
        """Store the data and the two DBSCAN parameters.

        Args:
            dataset: 2-D array-like of shape (N, M). An ndarray is used as-is;
                anything else is converted with ``np.asarray``.
            eps: neighbourhood radius.
            min_point: minimum neighbourhood size of a core point.

        Raises:
            ValueError: if ``dataset`` is not two-dimensional.
        """
        self.dataset = np.asarray(dataset)
        self.eps = eps
        self.min_point = min_point
        self.N, self.M = self.dataset.shape

    def get_neighbour_num(self, i):
        """Return the indices of all samples within ``eps`` of sample ``i``.

        Despite the name, this returns the neighbour indices, not their count.
        The list is in ascending index order and always contains ``i`` itself.

        Args:
            i: row index of the query sample.

        Returns:
            A list of int row indices.
        """
        distances = np.sum((self.dataset - self.dataset[i]) ** 2, axis=1) ** 0.5
        # Use the radius stored on the instance; relying on a module-level
        # ``eps`` breaks as soon as the class is imported from another module.
        return np.flatnonzero(distances <= self.eps).tolist()

    def train(self):
        """Cluster the dataset and return one label per sample.

        Seeds are drawn at random from the samples not yet visited. A seed
        that is not a core point is provisionally labelled noise; it can still
        be absorbed later as a border point of a cluster. A seed that is a
        core point starts a new cluster, which is grown through every core
        point reachable from it.

        A border point lying within ``eps`` of core points from two different
        clusters ends up in whichever of those clusters was expanded last.

        Returns:
            A list of length N: ``-1`` for noise, otherwise the 1-based number
            of the cluster, numbered in the order clusters were discovered.
        """
        # Whether each point has been processed already.
        visited = [False] * self.N
        # Cluster label of each point (0 = not assigned yet).
        cluster = [0] * self.N
        k = 1
        # Neighbour indices of every point, computed once up front.
        neighbour_list = [self.get_neighbour_num(i) for i in range(self.N)]

        while True:
            unvisited = [i for i, seen in enumerate(visited) if not seen]
            if not unvisited:
                break
            # Pick a random seed among the points not visited yet.
            index = random.choice(unvisited)

            if len(neighbour_list[index]) < self.min_point:
                visited[index] = True
                cluster[index] = -1
                continue

            # ``members`` holds every point ever scheduled for this cluster,
            # so each point is expanded at most once and the membership test
            # is O(1). Expansion order does not affect the resulting set.
            members = {index}
            pending = [index]
            while pending:
                start = pending.pop()
                # Only core points propagate the cluster further; border
                # points join the cluster but do not expand it.
                if len(neighbour_list[start]) >= self.min_point:
                    for neighbour in neighbour_list[start]:
                        if neighbour not in members:
                            members.add(neighbour)
                            pending.append(neighbour)
                visited[start] = True

            for member in members:
                cluster[member] = k
            k += 1

        return cluster
if __name__ == '__main__':
    # Demo data: two concentric noisy rings plus one compact blob.
    x1, y1 = datasets.make_circles(n_samples=2000, factor=.6, noise=.02)
    x2, y2 = datasets.make_blobs(n_samples=400, n_features=2, centers=[[1.2, 1.2]],
                                 cluster_std=[[.1]], random_state=9)
    dataset = np.concatenate((x1, x2))
    target = np.concatenate((y1, y2))

    eps = 0.08
    min_point = 10
    model = DBSCAN(dataset, eps, min_point)
    clusters = model.train()

    # Print the generator's labels followed by the labels DBSCAN assigned.
    print(list(target))
    print(clusters)

    # One scatter call per cluster number 1..3, so each gets its own colour.
    labels = np.array(clusters)
    for label in (1, 2, 3):
        members = dataset[labels == label]
        plt.scatter(members[:, 0], members[:, 1])
    plt.show()