目录
- 3.1 k近邻算法
- 3.2 k近邻模型
- 3.2.1 模型
- 3.2.2 距离度量
- 3.2.3 k值的选择
- 3.2.4 分类决策规则
- 3.3 k近邻算法的实现:kd树
- 3.3.1 构造kd树
- 3.3.2 搜索kd树
- 代码
3.1 k近邻算法
- k近邻法(k-nearest neighbor,k-NN):给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最近邻的k个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类。KNN使用的模型实际上对应于特征空间的划分,没有显式的训练过程。
3.2 k近邻模型
3.2.1 模型
- 该模型有三个基本要素:距离度量,k值的选择,分类决策规则.当这三个要素确定后,便能对于任何一个新的输入实例,给出唯一确定的分类.这里用图片说明更清楚,对于训练集中的每一个样本,距离该点比其他点更近的所有点组成一片区域,叫做单元.每个样本都拥有一个单元,所有样本的单元最终构成对整个特征空间的划分,且对每个样本而言,它的标签就是该单元内所有点的标记.这样每个单元的样本点的标签也就是唯一确定的.
三要素
3.2.2 距离度量
距离:特征空间中两个实例点的距离是两个实例点相优程度的反映。设输入实例
距离定义为:
- 欧氏距离:
时的特殊情况。
- 曼哈顿距离:
时的特殊情况。
- 切比雪夫距离:
时的特殊情况。
3.2.3 k值的选择
- 较小的k值代表整体模型变得复杂,分类结果容易被噪声点影响,容易发生过拟合。
较大的k值代表整体模型变得简单,容易欠拟合。
在应用中,k值一般取一个比较小的数值,通常采用交叉验证法来选取最优的k值。
3.2.4 分类决策规则
- k近邻法中的分类决策规则往往是多数表决,即由输入实例的k个邻近的训练实例中的多数类决定输入实例的类。
- 多数表决规则
如果分类的损失函数为损失函数,分类函数为
那误分类的概率是
对给定的实例其最近邻的
个训练实例点构成集合
。如果涵盖
的区域类别是
,那误分类率是
要使误分类率最小即经验风险最小,就要使最大,所以多数表决规则等价于经验风险最小化。
3.3 k近邻算法的实现:kd树
3.3.1 构造kd树
- 输入: k 维空间数据集:
其中,
输出:kd树
- 开始:构造根节点。
选取为坐标轴,以训练集中的所有数据
- 重 复
对深度为的结点, 选择
为切分坐标轴,
, 以该结点区域中所有实例
坐标的中位数
作为切分点, 将区域分为两个子区域。
生成深度为 - 直到两个子 区域没有实例时停止。
- 举例
输入:训练集:
输出: kd 树
开始:选择为坐标轴,中位数为
, 即
为切分点, 切分整个区域
- 再次划分区域:
以为坐标轴,选择中位数, 左边区域为
, 右边区域为
。故左边区域切分点为
, 右边区域切分点坐标为
- 划分左边区域:
以为坐标轴,选择中位数,上边区域为
, 下边区域为
。故上边 区域切分点为
, 下边区域切分点坐标为
- 划分右边区域: 以
为坐标轴,选择中位数, 上边区域无实例点, 下边区域为
。故 下边区域切分点坐标为
- 最终划分结果
- kd树
- 至此算法完成
3.3.2 搜索kd树
- 用kd树的最近邻搜索
- 寻找 “当前最近点”
寻找最近邻的子结点作为目标点的“当前最近点”。 - 回溯
以目标点和“当前最近点” 的距离沿树根部进行回溯和迭代。 - 详细描述
输入:已构造的 kd 树, 目标点 x
输出: x 的最近邻
- 寻找“当前最近点”
- 从根结点出发, 递归访问 kd 树, 找出包含 x 的叶结点;
以此叶结点为“当前最近点"”;
- 回溯
- 若该结点比 “当前最近点” 距离目标点更近, 更新“当前最近点”;
- 当前最近点一定存在于该结点一个子结点对应的区域, 检查子结点 的父结点的另一子结点对应的区域是否有更近的点。
- 当回退到根结点时, 搜索结束, 最后的“当前最近点” 即为 x 的最近邻点。
- 举例
输入: kd 树, 目 标点;
输出:最近邻点 - 第一次回溯
- 第二次回溯,最近邻点:
- 如果实例点是随机分布的, kd 树搜索的平均计算复杂度是
, 这里
是 训练实例数。 kd 树更适用于训练实例数远大于空间维数时的 k 近邻搜索。当空间维委
接近训练实例数时,它的效率会迅速下降, 几乎接近线性扫描。
代码
import torch
import random
import matplotlib.pyplot as plt
class DrawTool():
"""画图类"""
# 画点[数据集,x点,离 x点 最近的点]
def drawPoint(self, points, x, nearestPoint):
XMax = max(points[:, 0]) # X 轴范围
YMax = max(points[:, 1]) # Y 轴范围
precision = max(XMax, YMax) // 10 + 1 # 坐标轴精度
#plt.rcParams['font.sans-serif'] = ['SimHei'] # 防止中文乱码
plt.scatter(points[:, 0], points[:, 1], label="data")
plt.scatter(x[0], x[1], c='c', marker='*', s=100, label="x(input)")
plt.scatter(nearestPoint[0], nearestPoint[1], c='r', label="nearest")
plt.xticks(torch.arange(0, XMax, precision)) # 设置 X 轴
plt.yticks(torch.arange(0, YMax, precision)) # 设置 Y 轴
plt.legend(loc='upper left')
plt.show()
class DataLoader():
""""数据加载类"""
# 初始化[creat:人造数据集,random:随机数据集]
def __init__(self, kind="creat"):
self.points = None
if (kind == "creat"):
self.x = [2, 5, 9, 4, 8, 7]
self.y = [3, 4, 6, 7, 1, 2]
elif kind == "random":
nums = random.randint(20, 40)
self.x = random.sample(range(0, 40), nums)
self.y = random.sample(range(0, 40), nums)
# 处理数据
def getData(self):
self.points = [[self.x[i], self.y[i]] for i in range(len(self.x))]
return self.points
# 得到一个与数据集不重复的点,作为 x 点
def getRandomPoint(self):
points = torch.tensor(self.points)
x, y, i = -1, -1, 0
while x == -1 or y == -1:
if x == -1 and i not in points[:, 0]:
x = i
if y == -1 and i not in points[:, 1]:
y = i
i += 2
return x, y
class KDNode():#二叉树
""""节点类"""
def __init__(self, point):
self.point = point
self.left = None
self.right = None
class KDTree():
"""KD树"""
def __init__(self):
self.root = None
self.nearestPoint = None
self.nearestDis = float('inf')
# 创造和搜索 KD树[数据集,x]
def creatAndSearch(self, points, x):
self.root = self.creatTree(points)
self.searchTree(self.root, x)
# 创造 KD树[数据集,维度]
def creatTree(self, points, col=0):
if len(points) == 0:
return None
points = sorted(points, key=lambda point: point[col])
mid = len(points) >> 1
node = KDNode(points[mid])
node.left = self.creatTree(points[0:mid], col ^ 1)
node.right = self.creatTree(points[mid + 1:len(points)], col ^ 1)
return node
# 搜索 KD树[KD树,x,维度]
def searchTree(self, tree, x, col=0):
if tree == None:
return
# 对应算法中第 1 步
if x[col] < tree.point[col]:
self.searchTree(tree.left, x, col ^ 1)
else:
self.searchTree(tree.right, x, col ^ 1)
disCurAndX = self.dis(tree.point, x)
if disCurAndX < self.nearestDis:
self.nearestDis = disCurAndX
self.nearestPoint = tree.point
# 判断目前最小圆是否与其他区域相交,即判断 |x(按轴读值)-节点(按轴读值)| < 最近的值(圆的半径)
# 对应算法中第 3 步中的 (b)
if abs(tree.point[col] - x[col]) < self.nearestDis:
if tree.point[col] < x[col]:
self.searchTree(tree.right, x, col ^ 1)
else:
self.searchTree(tree.left, x, col ^ 1)
# 两点间距离[a点,b点]
def dis(self, a, b):
return sum([(a[i] - b[i]) ** 2 for i in range(len(a))]) ** 0.5#欧氏距离
# 前序遍历 KD树(测试使用)[KD树]
def printTree(self, root):
if root != None:
print(root.point)
self.printTree(root.left)
self.printTree(root.right)
if __name__ == '__main__':
drawTool = DrawTool()
dataLoader = DataLoader("random")
kdTree = KDTree()
points = dataLoader.getData()
x = dataLoader.getRandomPoint()
kdTree.creatAndSearch(points, x)
drawTool.drawPoint(torch.tensor(points), x, kdTree.nearestPoint)