K-means聚类模型
K-means聚类是一种常用的无监督学习算法,用于将数据集分为K个互不相交的子集(簇),每个子集内部的样本相似度较高,而不同子集之间的样本相似度较低。本文将详细介绍K-means聚类算法的原理、步骤及其Java实现,包括测试方法。
K-means聚类算法原理
K-means聚类算法的目标是通过迭代优化,使每个簇内的样本尽可能紧密,而簇间的距离尽可能远。算法的基本步骤如下:
- 初始化:随机选择K个初始质心(Centroids)。
- 分配样本:将每个样本分配到距离最近的质心所属的簇。
- 更新质心:计算每个簇的质心,作为该簇的新质心。
- 迭代:重复步骤2和3,直到质心位置不再变化或达到最大迭代次数。
Java实现K-means聚类
数据点类
首先,定义一个表示数据点的类:
public class DataPoint {
private double x;
private double y;
public DataPoint(double x, double y) {
this.x = x;
this.y = y;
}
public double getX() {
return x;
}
public double getY() {
return y;
}
public double distanceTo(DataPoint other) {
return Math.sqrt(Math.pow(this.x - other.x, 2) + Math.pow(this.y - other.y, 2));
}
}
K-means算法类
接下来,定义K-means算法的实现类:
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
public class KMeans {
private int k; // 簇的数量
private int maxIterations; // 最大迭代次数
private List<DataPoint> dataPoints; // 数据点列表
private List<DataPoint> centroids; // 质心列表
public KMeans(int k, int maxIterations) {
this.k = k;
this.maxIterations = maxIterations;
this.dataPoints = new ArrayList<>();
this.centroids = new ArrayList<>();
}
public void addDataPoint(DataPoint point) {
dataPoints.add(point);
}
public void fit() {
// 随机初始化质心
Random random = new Random();
for (int i = 0; i < k; i++) {
centroids.add(dataPoints.get(random.nextInt(dataPoints.size())));
}
for (int iteration = 0; iteration < maxIterations; iteration++) {
List<List<DataPoint>> clusters = new ArrayList<>();
for (int i = 0; i < k; i++) {
clusters.add(new ArrayList<>());
}
// 分配数据点到最近的质心
for (DataPoint point : dataPoints) {
double minDistance = Double.MAX_VALUE;
int closestCentroid = -1;
for (int i = 0; i < k; i++) {
double distance = point.distanceTo(centroids.get(i));
if (distance < minDistance) {
minDistance = distance;
closestCentroid = i;
}
}
clusters.get(closestCentroid).add(point);
}
// 更新质心
boolean centroidsChanged = false;
for (int i = 0; i < k; i++) {
double sumX = 0;
double sumY = 0;
List<DataPoint> cluster = clusters.get(i);
for (DataPoint point : cluster) {
sumX += point.getX();
sumY += point.getY();
}
DataPoint newCentroid = new DataPoint(sumX / cluster.size(), sumY / cluster.size());
if (!newCentroid.equals(centroids.get(i))) {
centroids.set(i, newCentroid);
centroidsChanged = true;
}
}
if (!centroidsChanged) {
break;
}
}
}
public List<DataPoint> getCentroids() {
return centroids;
}
}
测试方法
最后,编写一个测试方法来验证K-means算法的实现:
public class KMeansTest {
public static void main(String[] args) {
KMeans kMeans = new KMeans(3, 100);
// 添加测试数据点
kMeans.addDataPoint(new DataPoint(1, 1));
kMeans.addDataPoint(new DataPoint(2, 1));
kMeans.addDataPoint(new DataPoint(4, 3));
kMeans.addDataPoint(new DataPoint(5, 4));
kMeans.addDataPoint(new DataPoint(8, 8));
kMeans.addDataPoint(new DataPoint(9, 8));
// 执行聚类
kMeans.fit();
// 输出质心
List<DataPoint> centroids = kMeans.getCentroids();
for (int i = 0; i < centroids.size(); i++) {
System.out.println("Centroid " + (i + 1) + ": (" + centroids.get(i).getX() + ", " + centroids.get(i).getY() + ")");
}
}
}
结论
本文详细介绍了K-means聚类算法的原理和步骤,并提供了完整的Java实现代码和测试方法。通过这些代码示例,读者可以更好地理解K-means算法的工作机制,并在实际项目中应用该算法进行数据聚类。K-means聚类广泛应用于图像处理、市场营销、文本分类等领域,是数据科学家和机器学习工程师必备的工具之一。