0
点赞
收藏
分享

微信扫一扫

【李航】统计学习方法--3. k近邻法(详细推导)


【李航】统计学习方法--3. k近邻法(详细推导)_分类算法



目录

  • ​​3.1 k近邻算法​​
  • ​​3.2 k近邻模型​​
  • ​​3.2.1 模型​​
  • ​​3.2.2 距离度量​​
  • ​​3.2.3 k值的选择​​
  • ​​3.2.4 分类决策规则​​
  • ​​3.3 k近邻算法的实现:kd树​​
  • ​​3.3.1 构造kd树​​
  • ​​3.3.2 搜索kd树​​
  • ​​代码​​



3.1 k近邻算法

  • k近邻法(k-nearest neighbor,k-NN):给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最近邻的k个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类。KNN使用的模型实际上对应于特征空间的划分,没有显式的训练过程。
  • 【李航】统计学习方法--3. k近邻法(详细推导)_机器学习_02

3.2 k近邻模型

3.2.1 模型

  • 该模型有三个基本要素:距离度量,k值的选择,分类决策规则.当这三个要素确定后,便能对于任何一个新的输入实例,给出唯一确定的分类.这里用图片说明更清楚,对于训练集中的每一个样本,距离该点比其他点更近的所有点组成一片区域,叫做单元.每个样本都拥有一个单元,所有样本的单元最终构成对整个特征空间的划分,且对每个样本而言,它的标签就是该单元内所有点的标记.这样每个单元的样本点的标签也就是唯一确定的.
  • 【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_03

三要素

3.2.2 距离度量

  • 【李航】统计学习方法--3. k近邻法(详细推导)_分类算法_04距离:特征空间中两个实例点的距离是两个实例点相优程度的反映。设输入实例【李航】统计学习方法--3. k近邻法(详细推导)_机器学习_05距离定义为:
    【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_06
  • 欧氏距离:【李航】统计学习方法--3. k近邻法(详细推导)_机器学习_07时的特殊情况。
    【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_08
  • 曼哈顿距离:【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_09时的特殊情况。
    【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_10
  • 切比雪夫距离:【李航】统计学习方法--3. k近邻法(详细推导)_分类算法_11时的特殊情况。
    【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_12
  • 【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_13

3.2.3 k值的选择

  • 较小的k值代表整体模型变得复杂,分类结果容易被噪声点影响,容易发生过拟合。
    较大的k值代表整体模型变得简单,容易欠拟合。
    在应用中,k值一般取一个比较小的数值,通常采用交叉验证法来选取最优的k值。

3.2.4 分类决策规则

  • k近邻法中的分类决策规则往往是多数表决,即由输入实例的k个邻近的训练实例中的多数类决定输入实例的类。
  • 多数表决规则
    如果分类的损失函数为【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_14损失函数,分类函数为
    【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_15
    那误分类的概率是
    【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_16
    对给定的实例【李航】统计学习方法--3. k近邻法(详细推导)_机器学习_17其最近邻的【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_18个训练实例点构成集合【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_19。如果涵盖【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_19的区域类别是【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_21,那误分类率是
    【李航】统计学习方法--3. k近邻法(详细推导)_分类算法_22
    要使误分类率最小即经验风险最小,就要使【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_23最大,所以多数表决规则等价于经验风险最小化。

3.3 k近邻算法的实现:kd树

3.3.1 构造kd树

  • 输入: k 维空间数据集:
    【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_24
    其中,【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_25
    输出:kd树
  1. 开始:构造根节点。
    选取【李航】统计学习方法--3. k近邻法(详细推导)_分类算法_26为坐标轴,以训练集中的所有数据【李航】统计学习方法--3. k近邻法(详细推导)_分类算法_26
  2. 重 复
    对深度为【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_28的结点, 选择【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_29为切分坐标轴,【李航】统计学习方法--3. k近邻法(详细推导)_机器学习_30, 以该结点区域中所有实例【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_31坐标的中位数
    作为切分点, 将区域分为两个子区域。
    生成深度为【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_32
  3. 直到两个子 区域没有实例时停止。
  • 举例
    输入:训练集:
    【李航】统计学习方法--3. k近邻法(详细推导)_机器学习_33
    输出: kd 树
  • 【李航】统计学习方法--3. k近邻法(详细推导)_分类算法_34

  • 【李航】统计学习方法--3. k近邻法(详细推导)_机器学习_35
    开始:选择【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_36为坐标轴,中位数为【李航】统计学习方法--3. k近邻法(详细推导)_机器学习_37, 即【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_38为切分点, 切分整个区域
  • 【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_39

  • 再次划分区域:
    【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_40为坐标轴,选择中位数, 左边区域为【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_41, 右边区域为【李航】统计学习方法--3. k近邻法(详细推导)_机器学习_42。故左边区域切分点为【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_43, 右边区域切分点坐标为【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_44
  • 【李航】统计学习方法--3. k近邻法(详细推导)_机器学习_45

  • 划分左边区域:
    【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_36为坐标轴,选择中位数,上边区域为【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_41, 下边区域为【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_48。故上边 区域切分点为【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_49, 下边区域切分点坐标为【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_50
  • 【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_51

  • 划分右边区域: 以【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_36为坐标轴,选择中位数, 上边区域无实例点, 下边区域为【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_53。故 下边区域切分点坐标为【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_54
  • 【李航】统计学习方法--3. k近邻法(详细推导)_机器学习_55

  • 最终划分结果
  • 【李航】统计学习方法--3. k近邻法(详细推导)_机器学习_56

  • kd树
  • 【李航】统计学习方法--3. k近邻法(详细推导)_分类算法_57

  • 至此算法完成

3.3.2 搜索kd树

  • 用kd树的最近邻搜索
  • 寻找 “当前最近点”
    寻找最近邻的子结点作为目标点的“当前最近点”。
  • 回溯
    以目标点和“当前最近点” 的距离沿树根部进行回溯和迭代。
  • 详细描述
    输入:已构造的 kd 树, 目标点 x
    输出: x 的最近邻
  • 寻找“当前最近点”
  • 从根结点出发, 递归访问 kd 树, 找出包含 x 的叶结点;
    以此叶结点为“当前最近点"”;
  • 回溯
  • 若该结点比 “当前最近点” 距离目标点更近, 更新“当前最近点”;
  • 当前最近点一定存在于该结点一个子结点对应的区域, 检查子结点 的父结点的另一子结点对应的区域是否有更近的点。
  • 当回退到根结点时, 搜索结束, 最后的“当前最近点” 即为 x 的最近邻点。
  • 举例
    输入: kd 树, 目 标点【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_58;
    输出:最近邻点
  • 【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_59

  • 第一次回溯
  • 【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_60

  • 第二次回溯,最近邻点:【李航】统计学习方法--3. k近邻法(详细推导)_最近邻分类算法_50
  • 【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_62

  • 如果实例点是随机分布的, kd 树搜索的平均计算复杂度是【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_63, 这里【李航】统计学习方法--3. k近邻法(详细推导)_人工智能_64是 训练实例数。 kd 树更适用于训练实例数远大于空间维数时的 k 近邻搜索。当空间维委
    接近训练实例数时,它的效率会迅速下降, 几乎接近线性扫描。

代码

import torch
import random
import matplotlib.pyplot as plt


class DrawTool():
"""画图类"""

# 画点[数据集,x点,离 x点 最近的点]
def drawPoint(self, points, x, nearestPoint):
XMax = max(points[:, 0]) # X 轴范围
YMax = max(points[:, 1]) # Y 轴范围
precision = max(XMax, YMax) // 10 + 1 # 坐标轴精度
#plt.rcParams['font.sans-serif'] = ['SimHei'] # 防止中文乱码
plt.scatter(points[:, 0], points[:, 1], label="data")
plt.scatter(x[0], x[1], c='c', marker='*', s=100, label="x(input)")
plt.scatter(nearestPoint[0], nearestPoint[1], c='r', label="nearest")
plt.xticks(torch.arange(0, XMax, precision)) # 设置 X 轴
plt.yticks(torch.arange(0, YMax, precision)) # 设置 Y 轴
plt.legend(loc='upper left')
plt.show()


class DataLoader():
""""数据加载类"""

# 初始化[creat:人造数据集,random:随机数据集]
def __init__(self, kind="creat"):
self.points = None
if (kind == "creat"):
self.x = [2, 5, 9, 4, 8, 7]
self.y = [3, 4, 6, 7, 1, 2]
elif kind == "random":
nums = random.randint(20, 40)
self.x = random.sample(range(0, 40), nums)
self.y = random.sample(range(0, 40), nums)

# 处理数据
def getData(self):
self.points = [[self.x[i], self.y[i]] for i in range(len(self.x))]
return self.points

# 得到一个与数据集不重复的点,作为 x 点
def getRandomPoint(self):
points = torch.tensor(self.points)
x, y, i = -1, -1, 0
while x == -1 or y == -1:
if x == -1 and i not in points[:, 0]:
x = i
if y == -1 and i not in points[:, 1]:
y = i
i += 2
return x, y


class KDNode():#二叉树
""""节点类"""

def __init__(self, point):
self.point = point
self.left = None
self.right = None


class KDTree():
"""KD树"""

def __init__(self):
self.root = None
self.nearestPoint = None
self.nearestDis = float('inf')

# 创造和搜索 KD树[数据集,x]
def creatAndSearch(self, points, x):
self.root = self.creatTree(points)
self.searchTree(self.root, x)

# 创造 KD树[数据集,维度]
def creatTree(self, points, col=0):
if len(points) == 0:
return None
points = sorted(points, key=lambda point: point[col])
mid = len(points) >> 1
node = KDNode(points[mid])
node.left = self.creatTree(points[0:mid], col ^ 1)
node.right = self.creatTree(points[mid + 1:len(points)], col ^ 1)
return node

# 搜索 KD树[KD树,x,维度]
def searchTree(self, tree, x, col=0):
if tree == None:
return

# 对应算法中第 1 步
if x[col] < tree.point[col]:
self.searchTree(tree.left, x, col ^ 1)
else:
self.searchTree(tree.right, x, col ^ 1)

disCurAndX = self.dis(tree.point, x)
if disCurAndX < self.nearestDis:
self.nearestDis = disCurAndX
self.nearestPoint = tree.point

# 判断目前最小圆是否与其他区域相交,即判断 |x(按轴读值)-节点(按轴读值)| < 最近的值(圆的半径)
# 对应算法中第 3 步中的 (b)
if abs(tree.point[col] - x[col]) < self.nearestDis:
if tree.point[col] < x[col]:
self.searchTree(tree.right, x, col ^ 1)
else:
self.searchTree(tree.left, x, col ^ 1)

# 两点间距离[a点,b点]
def dis(self, a, b):
return sum([(a[i] - b[i]) ** 2 for i in range(len(a))]) ** 0.5#欧氏距离

# 前序遍历 KD树(测试使用)[KD树]
def printTree(self, root):
if root != None:
print(root.point)
self.printTree(root.left)
self.printTree(root.right)


if __name__ == '__main__':
drawTool = DrawTool()
dataLoader = DataLoader("random")
kdTree = KDTree()

points = dataLoader.getData()
x = dataLoader.getRandomPoint()

kdTree.creatAndSearch(points, x)
drawTool.drawPoint(torch.tensor(points), x, kdTree.nearestPoint)

【李航】统计学习方法--3. k近邻法(详细推导)_k近邻法_65


举报

相关推荐

李航统计学习实现

0 条评论