0
点赞
收藏
分享

微信扫一扫

复现经典:《统计学习方法》第1章 统计学习方法概论


本文是李航老师的《统计学习方法》[1]一书的代码复现。

备注:代码都可以在github[3]中下载。

我将陆续将代码发布在公众号“机器学习初学者”,敬请关注。

代码目录

  • 第 1 章 统计学习方法概论
  • 第 2 章 感知机
  • 第 3 章 k 近邻法
  • 第 4 章 朴素贝叶斯
  • 第 5 章 决策树
  • 第 6 章 逻辑斯谛回归
  • 第 7 章 支持向量机
  • 第 8 章 提升方法
  • 第 9 章 EM 算法及其推广
  • 第 10 章 隐马尔可夫模型
  • 第 11 章 条件随机场
  • 第 12 章 监督学习方法总结

代码参考:wzyonggege[4],WenDesi[5],火烫火烫的[6]

第 1 章 统计学习方法概论

1.统计学习是关于计算机基于数据构建概率统计模型并运用模型对数据进行分析与预测的一门学科。统计学习包括监督学习、非监督学习、半监督学习和强化学习。

2.统计学习方法三要素——模型、策略、算法,对理解统计学习方法起到提纲挈领的作用。

3.本书主要讨论监督学习,监督学习可以概括如下:从给定有限的训练数据出发, 假设数据是独立同分布的,而且假设模型属于某个假设空间,应用某一评价准则,从假设空间中选取一个最优的模型,使它对已给训练数据及未知测试数据在给定评价标准意义下有最准确的预测。

4.统计学习中,进行模型选择或者说提高学习的泛化能力是一个重要问题。如果只考虑减少训练误差,就可能产生过拟合现象。模型选择的方法有正则化与交叉验证。学习方法泛化能力的分析是统计学习理论研究的重要课题。

5.分类问题、标注问题和回归问题都是监督学习的重要问题。本书中介绍的统计学习方法包括感知机、 近邻法、朴素贝叶斯法、决策树、逻辑斯谛回归与最大熵模型、支持向量机、提升方法、EM 算法、隐马尔可夫模型和条件随机场。这些方法是主要的分类、标注以及回归方法。它们又可以归类为生成方法与判别方法。

使用最小二乘法拟和曲线

高斯于 1823 年在误差 独立同分布的假定下,证明了最小二乘方法的一个最优性质: 在所有无偏的线性估计类中,最小二乘方法是其中方差最小的!对于数据

拟合出函数

有误差,即残差:

此时 范数(残差平方和)最小时,  和 

一般的 为 次的多项式,

为参数

最小二乘法就是要找到一组   ,使得

即,求 

举例:我们用目标函数 , 加上一个正态分布的噪音干扰,用多项式去拟合【例 1.1 11 页】

import numpy as np
import scipy as sp
from scipy.optimize import leastsq
import matplotlib.pyplot as plt
%matplotlib inline

  • ps: numpy.poly1d([1,2,3]) 生成  *

# 目标函数
def real_func(x):
return np.sin(2*np.pi*x)


# 多项式
def fit_func(p, x):
f = np.poly1d(p)
return f(x)


# 残差
def residuals_func(p, x, y):
ret = fit_func(p, x) - y
return ret

# 十个点
x = np.linspace(0, 1, 10)
x_points = np.linspace(0, 1, 1000)
# 加上正态分布噪音的目标函数的值
y_ = real_func(x)
y = [np.random.normal(0, 0.1) + y1 for y1 in y_]




def fitting(M=0):
"""
M 为 多项式的次数
"""
# 随机初始化多项式参数
p_init = np.random.rand(M + 1)
# 最小二乘法
p_lsq = leastsq(residuals_func, p_init, args=(x, y))
print('Fitting Parameters:', p_lsq[0])


# 可视化
plt.plot(x_points, real_func(x_points), label='real')
plt.plot(x_points, fit_func(p_lsq[0], x_points), label='fitted curve')
plt.plot(x, y, 'bo', label='noise')
plt.legend()
return p_lsq

M=0

# M=0
p_lsq_0 = fitting(M=0)

Fitting Parameters: [0.02515259]


复现经典:《统计学习方法》第1章 统计学习方法概论_数据

M=1

# M=1
p_lsq_1 = fitting(M=1)

Fitting Parameters: [-1.50626624  0.77828571]


复现经典:《统计学习方法》第1章 统计学习方法概论_统计学习_02

M=3

# M=3
p_lsq_3 = fitting(M=3)

Fitting Parameters: [ 2.21147559e+01 -3.34560175e+01  1.13639167e+01 -2.82318048e-02]


复现经典:《统计学习方法》第1章 统计学习方法概论_github_03

M=9

# M=9
p_lsq_9 = fitting(M=9)

Fitting Parameters: [-1.70872086e+04  7.01364939e+04 -1.18382087e+05  1.06032494e+05
-5.43222991e+04 1.60701108e+04 -2.65984526e+03 2.12318870e+02
-7.15931412e-02 3.53804263e-02]


复现经典:《统计学习方法》第1章 统计学习方法概论_github_04

当 M=9 时,多项式曲线通过了每个数据点,但是造成了过拟合

正则化

结果显示过拟合, 引入正则化项(regularizer),降低过拟合

回归问题中,损失函数是平方损失,正则化可以是参数向量的 L2 范数,也可以是 L1 范数。

  • L1: regularization*abs(p)
  • L2: 0.5 * regularization * np.square(p)

regularization = 0.0001
def residuals_func_regularization(p, x, y):
ret = fit_func(p, x) - y
ret = np.append(ret,
np.sqrt(0.5 * regularization * np.square(p))) # L2范数作为正则化项
return ret

# 最小二乘法,加正则化项
p_init = np.random.rand(9 + 1)
p_lsq_regularization = leastsq(
residuals_func_regularization, p_init, args=(x, y))

plt.plot(x_points, real_func(x_points), label='real')
plt.plot(x_points, fit_func(p_lsq_9[0], x_points), label='fitted curve')
plt.plot(
x_points,
fit_func(p_lsq_regularization[0], x_points),
label='regularization')
plt.plot(x, y, 'bo', label='noise')
plt.legend()

参考资料

[1] 《统计学习方法》: https://baike.baidu.com/item/统计学习方法/10430179
[2] 黄海广: https://github.com/fengdu78
[3] github: https://github.com/fengdu78/lihang-code
[4] wzyonggege: https://github.com/wzyonggege/statistical-learning-method
[5] WenDesi: https://github.com/WenDesi/lihang_book_algorithm

复现经典:《统计学习方法》第1章 统计学习方法概论_github_05

关于本站

复现经典:《统计学习方法》第1章 统计学习方法概论_统计学习_06

举报

相关推荐

0 条评论