0
点赞
收藏
分享

微信扫一扫

金融数据_Scikit-Learn梯度提升树(GradientBoostingClassifier)实例

金融数据_Scikit-Learn梯度提升树(GradientBoostingClassifier)实例

梯度提升树 (Gradient Boosting Tree):梯度提升树是一种集成学习方法, 可以通过组合多个弱学习器提高模型性能。

对于分类问题, 可以使用梯度提升决策树。

在实际应用中, 你可能需要进行一些特征工程, 确保输入特征的质量和多样性。
此外, 股票市场的预测问题本身非常复杂, 因为受到众多因素的影响, 包括市场情绪、宏观经济指标等。
过去的股票价格并不一定能够准确预测未来的价格, 因此在实际应用中, 你需要考虑到这些限制。
最好的模型可能会根据具体的数据和问题而异, 因此建议尝试多个模型, 进行交叉验证, 并评估它们在你的数据集上的性能。

当使用 Scikit-Learn 进行梯度提升树的构建时, 可以使用 GradientBoostingClassifier 类。

下面是一个简单的示例, 演示如何在 Scikit-Learn 中构建和训练梯度提升树分类器。

实例数据

本实例截取了 “湖北宜化(000422)” 2015年08月06日 - 2015年12月31日的数据。

HBYH_000422_20150806_20151231.csv

Date,Code,Open,High,Low,Close,Pre_Close,Change,Turnover_Rate,Volume,MA5,MA10
2015-12-31,'000422,7.93,7.95,7.76,7.77,7.93,-0.020177,0.015498,13915200,7.86,7.85
2015-12-30,'000422,7.86,7.93,7.75,7.93,7.84,0.011480,0.018662,16755900,7.90,7.85
2015-12-29,'000422,7.72,7.85,7.69,7.84,7.71,0.016861,0.015886,14263800,7.90,7.81
2015-12-28,'000422,8.03,8.08,7.70,7.71,8.03,-0.039851,0.030821,27672800,7.91,7.78
2015-12-25,'000422,8.03,8.05,7.93,8.03,7.99,0.005006,0.021132,18974000,7.93,7.78
2015-12-24,'000422,7.93,8.16,7.87,7.99,7.92,0.008838,0.026487,23781900,7.85,7.72
2015-12-23,'000422,7.97,8.11,7.88,7.92,7.89,0.003802,0.042360,38033600,7.80,7.69
2015-12-22,'000422,7.86,7.93,7.76,7.89,7.83,0.007663,0.026929,24178700,7.73,7.68
2015-12-21,'000422,7.59,7.89,7.56,7.83,7.63,0.026212,0.030777,27633600,7.66,7.67
2015-12-18,'000422,7.71,7.74,7.57,7.63,7.74,-0.014212,0.024764,22234900,7.62,7.71
2015-12-17,'000422,7.58,7.75,7.57,7.74,7.55,0.025166,0.028054,25188400,7.59,7.77
2015-12-16,'000422,7.57,7.62,7.53,7.55,7.55,0.000000,0.020718,18601600,7.58,7.79
2015-12-15,'000422,7.63,7.66,7.52,7.55,7.62,-0.009186,0.025902,23256600,7.64,7.78
2015-12-14,'000422,7.40,7.64,7.36,7.62,7.51,0.014647,0.021005,18860100,7.68,7.76
2015-12-11,'000422,7.65,7.70,7.41,7.51,7.67,-0.020860,0.020477,18385900,7.80,7.73
2015-12-10,'000422,7.78,7.87,7.65,7.67,7.83,-0.020434,0.019972,17931900,7.95,7.69
2015-12-09,'000422,7.76,8.00,7.75,7.83,7.77,0.007722,0.025137,22569700,8.00,7.68
2015-12-08,'000422,8.08,8.18,7.76,7.77,8.24,-0.057039,0.036696,32948200,7.92,7.66
2015-12-07,'000422,8.12,8.39,7.94,8.24,8.23,0.001215,0.064590,57993100,7.84,7.64
2015-12-04,'000422,7.85,8.48,7.80,8.23,7.92,0.039141,0.100106,89881900,7.65,7.58
2015-12-03,'000422,7.42,8.09,7.38,7.92,7.43,0.065949,0.045416,40777500,7.43,7.52
2015-12-02,'000422,7.35,7.48,7.20,7.43,7.36,0.009511,0.015968,14337600,7.37,7.49
2015-12-01,'000422,7.28,7.39,7.23,7.36,7.33,0.004093,0.012308,11050700,7.41,7.48
2015-11-30,'000422,7.18,7.36,6.95,7.33,7.11,0.030942,0.020323,18247500,7.45,7.50
2015-11-27,'000422,7.59,7.59,6.95,7.11,7.60,-0.064474,0.027673,24846700,7.51,7.52
2015-11-26,'000422,7.63,7.73,7.58,7.60,7.63,-0.003932,0.024836,22299800,7.61,7.54
2015-11-25,'000422,7.56,7.64,7.51,7.63,7.59,0.005270,0.020919,18782900,7.61,7.54
2015-11-24,'000422,7.60,7.63,7.48,7.59,7.62,-0.003937,0.014867,13348200,7.56,7.53
2015-11-23,'000422,7.59,7.72,7.55,7.62,7.61,0.001314,0.028406,25505000,7.54,7.53
2015-11-20,'000422,7.59,7.71,7.53,7.61,7.59,0.002635,0.028277,25389100,7.52,7.53
2015-11-19,'000422,7.45,7.62,7.41,7.59,7.39,0.027064,0.038638,34691700,7.47,7.52
2015-11-18,'000422,7.53,7.54,7.38,7.39,7.51,-0.015979,0.014173,12725000,7.46,7.50
2015-11-17,'000422,7.53,7.63,7.44,7.51,7.50,0.001333,0.028640,25714500,7.51,7.50
2015-11-16,'000422,7.27,7.52,7.24,7.50,7.38,0.016260,0.016230,14572000,7.52,7.46
2015-11-13,'000422,7.49,7.55,7.36,7.38,7.54,-0.021220,0.029196,26214400,7.53,7.41
2015-11-12,'000422,7.65,7.68,7.49,7.54,7.61,-0.009198,0.026501,23794800,7.56,7.40
2015-11-11,'000422,7.57,7.64,7.52,7.61,7.57,0.005284,0.026113,23445900,7.54,7.37
2015-11-10,'000422,7.51,7.61,7.45,7.57,7.55,0.002649,0.024979,22427700,7.49,7.32
2015-11-09,'000422,7.51,7.62,7.45,7.55,7.53,0.002656,0.033367,29959500,7.39,7.31
2015-11-06,'000422,7.47,7.53,7.37,7.53,7.45,0.010738,0.037058,33273100,7.29,7.27
2015-11-05,'000422,7.34,7.54,7.32,7.45,7.37,0.010855,0.040463,36330200,7.24,7.24
2015-11-04,'000422,7.10,7.38,7.07,7.37,7.05,0.045390,0.034817,31260800,7.20,7.17
2015-11-03,'000422,7.08,7.13,7.02,7.05,7.06,-0.001416,0.014938,13412400,7.15,7.10
2015-11-02,'000422,7.11,7.26,7.05,7.06,7.26,-0.027548,0.016865,15142100,7.23,7.10
2015-10-30,'000422,7.22,7.38,7.10,7.26,7.24,0.002762,0.022821,20490200,7.25,7.10
2015-10-29,'000422,7.27,7.33,7.16,7.24,7.16,0.011173,0.025726,23098500,7.23,7.08
2015-10-28,'000422,7.32,7.40,7.09,7.16,7.42,-0.035040,0.035572,31938500,7.15,7.05
2015-10-27,'000422,7.21,7.48,7.08,7.42,7.18,0.033426,0.057658,51769300,7.04,7.01
2015-10-26,'000422,7.20,7.25,7.01,7.18,7.17,0.001395,0.036840,33077800,6.98,6.96
2015-10-23,'000422,6.84,7.22,6.81,7.17,6.80,0.054412,0.047169,42351500,6.95,6.93
2015-10-22,'000422,6.68,6.81,6.64,6.80,6.65,0.022556,0.020609,18503800,6.93,6.87
2015-10-21,'000422,7.08,7.11,6.61,6.65,7.09,-0.062059,0.039388,35365300,6.96,6.85
2015-10-20,'000422,7.00,7.09,6.94,7.09,7.03,0.008535,0.024472,21972900,6.98,6.81
2015-10-19,'000422,7.09,7.13,6.92,7.03,7.08,-0.007062,0.031262,28068800,6.94,6.72
2015-10-16,'000422,6.97,7.08,6.91,7.08,6.93,0.021645,0.039632,35584700,6.91,6.66
2015-10-15,'000422,6.77,6.94,6.75,6.93,6.77,0.023634,0.031645,28412700,6.82,6.59
2015-10-14,'000422,6.87,6.94,6.74,6.77,6.89,-0.017417,0.027226,24445500,6.74,6.55
2015-10-13,'000422,6.86,6.96,6.80,6.89,6.88,0.001453,0.028704,25771900,6.64,6.51
2015-10-12,'000422,6.62,6.91,6.58,6.88,6.61,0.040847,0.037037,33254300,6.50,6.49
2015-10-09,'000422,6.54,6.65,6.45,6.61,6.54,0.010703,0.018528,16635900,6.41,6.46
2015-10-08,'000422,6.45,6.70,6.37,6.54,6.26,0.044728,0.018857,16931000,6.35,6.44
2015-09-30,'000422,6.25,6.30,6.22,6.26,6.23,0.004815,0.007327,6579090,6.35,6.43
2015-09-29,'000422,6.30,6.32,6.18,6.23,6.40,-0.026562,0.008991,8072900,6.39,6.48
2015-09-28,'000422,6.35,6.42,6.25,6.40,6.34,0.009464,0.008824,7922890,6.48,6.47
2015-09-25,'000422,6.51,6.56,6.25,6.34,6.53,-0.029096,0.012584,11298800,6.51,6.45
2015-09-24,'000422,6.48,6.56,6.45,6.53,6.45,0.012403,0.011339,10180900,6.53,6.51
2015-09-23,'000422,6.51,6.60,6.41,6.45,6.67,-0.032984,0.015920,14294100,6.52,6.54
2015-09-22,'000422,6.58,6.73,6.54,6.67,6.58,0.013678,0.023356,20970200,6.56,6.60
2015-09-21,'000422,6.34,6.61,6.29,6.58,6.44,0.021739,0.017036,15295900,6.46,6.62
2015-09-18,'000422,6.52,6.58,6.30,6.44,6.44,0.000000,0.016622,14924700,6.39,6.62
2015-09-17,'000422,6.59,6.76,6.43,6.44,6.68,-0.035928,0.019517,17523900,6.48,6.62
2015-09-16,'000422,6.21,6.76,6.17,6.68,6.15,0.086179,0.019671,17662300,6.56,6.65
2015-09-15,'000422,6.24,6.38,6.05,6.15,6.26,-0.017572,0.015338,13771200,6.64,6.66
2015-09-14,'000422,6.89,6.95,6.18,6.26,6.87,-0.088792,0.021233,18559600,6.78,6.75
2015-09-11,'000422,6.87,6.96,6.77,6.87,6.84,0.004386,0.010853,9486290,6.85,6.79
2015-09-10,'000422,6.95,7.01,6.76,6.84,7.06,-0.031161,0.017423,15229100,6.76,6.74
2015-09-09,'000422,6.90,7.09,6.86,7.06,6.88,0.026163,0.028974,25325600,6.74,6.68
2015-09-08,'000422,6.65,6.91,6.55,6.88,6.62,0.039275,0.017858,15609100,6.69,6.67
2015-09-07,'000422,6.50,6.81,6.50,6.62,6.38,0.037618,0.017850,15602600,6.72,6.75
2015-09-02,'000422,6.45,6.88,6.30,6.38,6.74,-0.053412,0.022286,19480100,6.73,6.91
2015-09-01,'000422,6.88,6.99,6.67,6.74,6.81,-0.010279,0.025829,22576700,6.72,7.12
2015-08-31,'000422,6.90,6.97,6.71,6.81,7.07,-0.036775,0.018385,16069600,6.62,7.24
2015-08-28,'000422,6.75,7.08,6.71,7.07,6.67,0.059970,0.026692,23330800,6.65,7.44
2015-08-27,'000422,6.53,6.67,6.34,6.67,6.32,0.055380,0.022455,19627900,6.78,7.59
2015-08-26,'000422,6.31,6.77,6.09,6.32,6.25,0.011200,0.029963,26190200,7.08,7.76
2015-08-25,'000422,6.40,6.77,6.25,6.25,6.94,-0.099424,0.029492,25778600,7.52,7.96
2015-08-24,'000422,7.49,7.49,6.94,6.94,7.71,-0.099870,0.036552,31949900,7.86,8.18
2015-08-21,'000422,8.00,8.11,7.60,7.71,8.17,-0.056304,0.032199,28144800,8.23,8.33
2015-08-20,'000422,8.38,8.56,8.14,8.17,8.53,-0.042204,0.031764,27764200,8.40,8.38
2015-08-19,'000422,7.73,8.57,7.72,8.53,7.96,0.071608,0.052192,45619900,8.45,8.37
2015-08-18,'000422,8.81,8.86,7.92,7.96,8.80,-0.095455,0.056179,49105500,8.39,8.32
2015-08-17,'000422,8.49,8.83,8.42,8.80,8.52,0.032864,0.048161,42096900,8.50,8.35
2015-08-14,'000422,8.48,8.65,8.43,8.52,8.44,0.009479,0.041169,35985000,8.43,8.24
2015-08-13,'000422,8.20,8.45,8.15,8.44,8.24,0.024272,0.029768,26019600,8.37,8.16
2015-08-12,'000422,8.38,8.48,8.21,8.24,8.48,-0.028302,0.035421,30960700,8.30,8.08
2015-08-11,'000422,8.41,8.68,8.32,8.48,8.49,-0.001178,0.048444,42343900,8.26,8.03
2015-08-10,'000422,8.28,8.58,8.18,8.49,8.21,0.034105,0.041268,36071600,8.20,7.92
2015-08-07,'000422,8.15,8.28,8.08,8.21,8.07,0.017348,0.025855,22599800,8.05,7.81
2015-08-06,'000422,7.88,8.21,7.80,8.07,8.03,0.004981,0.020074,17546700,7.95,7.80

探索思路

这里只是简单示例, 目的在于熟悉 Scikit-Learn 中的梯度提升树分类器使用方法, 无任何投资引导。

目标:

通过当日数值情况, 预测当日收盘涨跌, 如果 “涨跌幅(Change) >= 0”, 则用 1 表示, 如果 “涨跌幅(Change) < 0”, 则用 0 表示 (二分类标签)。

变量:

  1. 当日最高价

  2. 当日最低价

  3. 当日换手率

  4. 当日成交量

  5. 当日星期几 (星期对价格的影响)

  6. 当日 “短期均线(MA5)” 与 “长期均线(MA10)” 的关系, 如果 “MA5 > MA10”, 则用 1 表示, 如果 “MA5 = MA10”, 则用 0 表示, 如果 “MA5 < MA10”, 则用 -1 表示。

  7. 节日 (节日对 A 股的影响, 中国节日 A 股休市, 所以只能探索国外节日对 A 股的影响, 这里仅用 “圣诞节(Christmas)” 和 “平安夜(Christmas Eve)” 做示例)。

导入 Pandas 相关模块

Pandas 是基于 NumPy 的一种工具, 该工具是为解决数据分析任务而创建的。Pandas 纳入了大量库和一些标准的数据模型, 提供了高效地操作大型数据集所需的工具。

Pandas 提供了大量能使我们快速便捷地处理数据的函数和方法。你很快就会发现, 它是使 Python 成为强大而高效的数据分析环境的重要因素之一。

import pandas as pd

导入 Scikit-Learn 相关模块

Scikit-Learn (以前称为 scikits.learn, 也称为 sklearn) 是针对 Python 编程语言的免费软件机器学习库。

它具有各种分类, 回归和聚类算法, 包括支持向量机, 随机森林, 梯度提升, K均值 和 DBSCAN, 并且旨在与 Python 数值科学库 NumPy 和 SciPy 联合使用。

from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler

使用 Pandas 读取 CSV 数据

调用 Pandas 的 .read_csv 方法读取 CSV 数据:

其中 header 参数指定 CSV 文件的表头行, 这里的 header=0 表示表头行在 1 行, 如果 header=None 则表示数据没有列索引, Pandas 则会自动加上索引。

其中 sep 参数指定 CSV 文件的分隔符, 默认情况下都是以 “,” 作为分隔符, 这里的 sep=“,” 表示指定 CSV 文件的分隔符为 “,”。

还有 dtype 参数指定 CSV 某些特定列以特定的数据类型进行读取, 例如 dtype={“Close”:float, “Volume”:int} 表示 “Close” 列以 浮点(float) 类型读取, “Volume” 列以 整数(integer) 类型读取。

PDF = pd.read_csv("D:\\HBYH_000422_20150806_20151231.csv", header=0, sep=",")

输出 DataFrame 数据框:

print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")
print(PDF)

输出:

[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv
          Date     Code  Open  High   Low  Close  Pre_Close    Change  Turnover_Rate    Volume   MA5  MA10
0   2015-12-31  '000422  7.93  7.95  7.76   7.77       7.93 -0.020177       0.015498  13915200  7.86  7.85
1   2015-12-30  '000422  7.86  7.93  7.75   7.93       7.84  0.011480       0.018662  16755900  7.90  7.85
2   2015-12-29  '000422  7.72  7.85  7.69   7.84       7.71  0.016861       0.015886  14263800  7.90  7.81
3   2015-12-28  '000422  8.03  8.08  7.70   7.71       8.03 -0.039851       0.030821  27672800  7.91  7.78
4   2015-12-25  '000422  8.03  8.05  7.93   8.03       7.99  0.005006       0.021132  18974000  7.93  7.78
..         ...      ...   ...   ...   ...    ...        ...       ...            ...       ...   ...   ...
94  2015-08-12  '000422  8.38  8.48  8.21   8.24       8.48 -0.028302       0.035421  30960700  8.30  8.08
95  2015-08-11  '000422  8.41  8.68  8.32   8.48       8.49 -0.001178       0.048444  42343900  8.26  8.03
96  2015-08-10  '000422  8.28  8.58  8.18   8.49       8.21  0.034105       0.041268  36071600  8.20  7.92
97  2015-08-07  '000422  8.15  8.28  8.08   8.21       8.07  0.017348       0.025855  22599800  8.05  7.81
98  2015-08-06  '000422  7.88  8.21  7.80   8.07       8.03  0.004981       0.020074  17546700  7.95  7.80

[99 rows x 12 columns]

转换 Pandas 中 DateFrame 各列数据类型

通常情况下, 为了避免计算出现数据类型的错误, 都需要重新转换一下数据类型。

# 转换 Pandas 中 DateFrame 数据类型。
PDF["Date"] =          PDF["Date"].astype("datetime64[ns]")
PDF["Open"] =          PDF["Open"].astype("float64")
PDF["High"] =          PDF["High"].astype("float64")
PDF["Low"] =           PDF["Low"].astype("float64")
PDF["Close"] =         PDF["Close"].astype("float64")
PDF["Pre_Close"] =     PDF["Pre_Close"].astype("float64")
PDF["Change"] =        PDF["Change"].astype("float64")
PDF["Turnover_Rate"] = PDF["Turnover_Rate"].astype("float64")
PDF["Volume"] =        PDF["Volume"].astype("int64")
PDF["MA5"] =           PDF["MA5"].astype("float64")
PDF["MA10"] =          PDF["MA10"].astype("float64")

# 输出 Pandas 中 DataFrame 字段和数据类型。
print("[Message] Changed Pandas DataFrame Data Type:")
print(PDF.dtypes)

输出:

[Message] Changed Pandas DataFrame Data Type:
Date             datetime64[ns]
Code                     object
Open                    float64
High                    float64
Low                     float64
Close                   float64
Pre_Close               float64
Change                  float64
Turnover_Rate           float64
Volume                    int64
MA5                     float64
MA10                    float64
dtype: object

在 Pandas 的 DataFrame 中计算数据

编写 “判断股票短期均线和长期均线关系” 函数:

def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int:

    if (Short_MA >= Long_MA): return  1
    if (Short_MA == Long_MA): return  0
    if (Short_MA <= Long_MA): return -1

    # ==============================================
    # End of Function.

在 Pandas 的 DataFrame 中直接计算或调用自定义函数:

# 计算数据: 提取星期的索引, 从 0 到 6 (0 代表周一, 6 代表周日)。
PDF["Weekday(Idx)"] =    PDF["Date"].apply(lambda X: X.weekday())
# ..................................................
# 计算数据: 计算节日 (节日对 A 股的影响, 中国节日 A 股休市, 所以只能探索国外节日对 A 股的影响, 这里仅用 "圣诞节(Christmas)" 和 "平安夜(Christmas Eve)" 做示例)。
PDF["Festival"] = None
for Idx in PDF.index:
    if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,24): PDF.loc[Idx, "Festival"] = "Christmas_Eve" # -> 平安夜。
    if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,25): PDF.loc[Idx, "Festival"] = "Christmas"     # -> 圣诞节。
# ..................................................
# 计算数据: 判断股票涨跌。
PDF["Rise_Fall"] =       PDF["Change"].apply(lambda X: int(1) if X >= 0 else int(0))
# ..................................................
# 计算数据: 调用函数, 判断股票短期均线和长期均线关系。
PDF["MA_Relationship"] = PDF.apply(lambda X: MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"]), axis=1)

# 输出计算好的 DataFrame 数据框。
print("[Message] Calculated DataFrame:")
print(PDF)

输出:

[Message] Calculated DataFrame:
         Date     Code  Open  High   Low  Close  Pre_Close    Change  Turnover_Rate    Volume   MA5  MA10  Weekday(Idx)   Festival  Rise_Fall  MA_Relationship
0  2015-12-31  '000422  7.93  7.95  7.76   7.77       7.93 -0.020177       0.015498  13915200  7.86  7.85             3       None          0                1
1  2015-12-30  '000422  7.86  7.93  7.75   7.93       7.84  0.011480       0.018662  16755900  7.90  7.85             2       None          1                1
2  2015-12-29  '000422  7.72  7.85  7.69   7.84       7.71  0.016861       0.015886  14263800  7.90  7.81             1       None          1                1
3  2015-12-28  '000422  8.03  8.08  7.70   7.71       8.03 -0.039851       0.030821  27672800  7.91  7.78             0       None          0                1
4  2015-12-25  '000422  8.03  8.05  7.93   8.03       7.99  0.005006       0.021132  18974000  7.93  7.78             4  Christmas          1                1
..        ...      ...   ...   ...   ...    ...        ...       ...            ...       ...   ...   ...           ...        ...        ...              ...
94 2015-08-12  '000422  8.38  8.48  8.21   8.24       8.48 -0.028302       0.035421  30960700  8.30  8.08             2       None          0                1
95 2015-08-11  '000422  8.41  8.68  8.32   8.48       8.49 -0.001178       0.048444  42343900  8.26  8.03             1       None          0                1
96 2015-08-10  '000422  8.28  8.58  8.18   8.49       8.21  0.034105       0.041268  36071600  8.20  7.92             0       None          1                1
97 2015-08-07  '000422  8.15  8.28  8.08   8.21       8.07  0.017348       0.025855  22599800  8.05  7.81             4       None          1                1
98 2015-08-06  '000422  7.88  8.21  7.80   8.07       8.03  0.004981       0.020074  17546700  7.95  7.80             3       None          1                1

[99 rows x 16 columns]

在 Pandas 的 DataFrame 中将字符串类型的特征列转换为数值 (One-Hot Encoding)

pd.get_dummies() 是 Pandas 库中用于独热编码 (One-Hot Encoding) 的函数。它的作用是将分类 (离散) 变量的每个不同取值都拓展为一个新的二进制特征 (0 或 1), 从而方便机器学习模型处理。

# 函数签名:
pd.get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False, columns=None, sparse=False, drop_first=False, dtype=None)

# 参数说明:
# - data: 要进行独热编码的 DataFrame 或 Series。
# - prefix: 生成的独热编码列的前缀。
# - prefix_sep: 生成的独热编码列的前缀和原始列名之间的分隔符。
# - dummy_na: 是否为原始数据中的缺失值生成独热编码列。
# - columns: 要进行独热编码的列的名称, 如果指定, 则只对这些列进行操作。
# - drop_first: 是否删除第一个独热编码列, 以避免共线性问题。

转换 Festival 特征列为数值:

# 将字符串类型的特征列转换为数值 (独热编码)。
PDF = pd.get_dummies(PDF, columns=["Festival"], drop_first=False)

# 输出转换后的 DataFrame 数据框。
print("[Message] DataFrame After One-Hot Encoding:")
print(PDF)

输出:

[Message] DataFrame After One-Hot Encoding:
         Date     Code  Open  High   Low  Close  Pre_Close    Change  Turnover_Rate    Volume   MA5  MA10  Weekday(Idx)  Rise_Fall  MA_Relationship  Festival_Christmas  Festival_Christmas_Eve
0  2015-12-31  '000422  7.93  7.95  7.76   7.77       7.93 -0.020177       0.015498  13915200  7.86  7.85             3          0                1                   0                       0
1  2015-12-30  '000422  7.86  7.93  7.75   7.93       7.84  0.011480       0.018662  16755900  7.90  7.85             2          1                1                   0                       0
2  2015-12-29  '000422  7.72  7.85  7.69   7.84       7.71  0.016861       0.015886  14263800  7.90  7.81             1          1                1                   0                       0
3  2015-12-28  '000422  8.03  8.08  7.70   7.71       8.03 -0.039851       0.030821  27672800  7.91  7.78             0          0                1                   0                       0
4  2015-12-25  '000422  8.03  8.05  7.93   8.03       7.99  0.005006       0.021132  18974000  7.93  7.78             4          1                1                   1                       0
..        ...      ...   ...   ...   ...    ...        ...       ...            ...       ...   ...   ...           ...        ...              ...                 ...                     ...
94 2015-08-12  '000422  8.38  8.48  8.21   8.24       8.48 -0.028302       0.035421  30960700  8.30  8.08             2          0                1                   0                       0
95 2015-08-11  '000422  8.41  8.68  8.32   8.48       8.49 -0.001178       0.048444  42343900  8.26  8.03             1          0                1                   0                       0
96 2015-08-10  '000422  8.28  8.58  8.18   8.49       8.21  0.034105       0.041268  36071600  8.20  7.92             0          1                1                   0                       0
97 2015-08-07  '000422  8.15  8.28  8.08   8.21       8.07  0.017348       0.025855  22599800  8.05  7.81             4          1                1                   0                       0
98 2015-08-06  '000422  7.88  8.21  7.80   8.07       8.03  0.004981       0.020074  17546700  7.95  7.80             3          1                1                   0                       0

[99 rows x 17 columns]

提取 标签(Label)列 和 特征(Feature)列

提取 标签(Label) 列:

# 提取 标签(Label) 列。
Y = PDF["Rise_Fall"]

提取 特征(Feature) 列:

# 提取 特征(Feature) 列。
X = PDF.drop(["Date", "Code", "Open", "Close", "Pre_Close", "Change", "MA5", "MA10", "Rise_Fall"], axis=1)

划分训练集和测试集(train_test_split) 以及 特征标准化(StandardScaler)

划分训练集和测试集(train_test_split):

# 数据集划分训练集和测试集(train_test_split)。
X_Train, X_Test, Y_Train, Y_Test = train_test_split(X, Y, test_size=0.2, random_state=42)

特征标准化(StandardScaler):

在机器学习中, fit_transform 和 transform 是用于数据预处理的常见方法, 它们的作用略有不同:

fit_transform: 该方法将同时拟合和转换数据。

  • 它会根据输入的数据计算所需的转换参数 (例如均值、标准差等), 然后将数据应用这些参数进行转换。

  • 在训练阶段, 通常使用 fit_transform 来对训练集进行拟合和转换。

  • 拟合过程会根据训练集数据计算并保存所需的转换参数, 然后将训练集数据应用这些参数进行转换。

  • 这样做的目的是确保在后续对测试集或新数据进行转换时使用相同的转换参数。

transform: 该方法仅对数据进行转换, 不进行拟合过程。

  • 它根据之前使用 fit_transform 得到的转换参数, 将这些参数应用于新的数据, 使其按照相同的转换方式进行处理。

  • 在测试阶段或对新数据应用模型时, 通常使用 transform 方法对测试集或新数据进行转换。

简而言之, fit_transform 方法用于拟合转换器并将数据进行转换, 而 transform 方法仅用于将数据按照已经拟合的转换器进行转换。

在代码中的具体应用上, 通常将 fit_transform 用于训练集的拟合和转换, 将 transform 用于测试集或新数据的转换, 以保证数据的一致性和正确的预处理操作。

# 特征标准化(StandardScaler)。
Obj_Scaler = StandardScaler()
X_Train_Scaled = Obj_Scaler.fit_transform(X_Train)
X_Test_Scaled = Obj_Scaler.transform(X_Test)

训练 梯度提升树分类器(GradientBoostingClassifier) 模型

创建 梯度提升树分类器(GradientBoostingClassifier):

# 创建 梯度提升树分类器(GradientBoostingClassifier)。
GBC = GradientBoostingClassifier(n_estimators=100, random_state=42)

训练 梯度提升树分类器(GradientBoostingClassifier) 模型:

# 训练 梯度提升树分类器(GradientBoostingClassifier) 模型。
GBC.fit(X_Train_Scaled, Y_Train)

# Value of Return:
# +--------------------------------------------+
# |▼        GradientBoostingClassifier         |
# +--------------------------------------------+
# | GradientBoostingClassifier(random_state=42)|
# +--------------------------------------------+

使用 梯度提升树分类器(GradientBoostingClassifier) 模型预测数据

# 在测试集上进行预测。
Y_Pred = GBC.predict(X_Test_Scaled)

# 合并预测结果。
Result = X_Test.copy()
Result["Actually"] = Y_Test
Result["Prediction"] = Y_Pred

print("[Message] Prediction Results on The Test Data Set for GradientBoostingClassifier:")
print(Result)

输出:

[Message] Prediction Results on The Test Data Set for GradientBoostingClassifier:
    High   Low  Turnover_Rate    Volume  Weekday(Idx)  MA_Relationship  Festival_Christmas  Festival_Christmas_Eve  Actually  Prediction
62  6.32  6.18       0.008991   8072900             1               -1                   0                       0         0           1
40  7.54  7.32       0.040463  36330200             3                1                   0                       0         1           1
95  8.68  8.32       0.048444  42343900             1                1                   0                       0         0           1
18  8.39  7.94       0.064590  57993100             0                1                   0                       0         1           1
97  8.28  8.08       0.025855  22599800             4                1                   0                       0         1           1
84  6.77  6.09       0.029963  26190200             2               -1                   0                       0         1           0
64  6.56  6.25       0.012584  11298800             4                1                   0                       0         0           1
42  7.13  7.02       0.014938  13412400             1                1                   0                       0         0           0
10  7.75  7.57       0.028054  25188400             3               -1                   0                       0         1           1
0   7.95  7.76       0.015498  13915200             3                1                   0                       0         0           1
31  7.54  7.38       0.014173  12725000             2               -1                   0                       0         0           0
76  7.09  6.86       0.028974  25325600             2                1                   0                       0         1           1
47  7.48  7.08       0.057658  51769300             1                1                   0                       0         1           1
26  7.64  7.51       0.020919  18782900             2                1                   0                       0         1           1
44  7.38  7.10       0.022821  20490200             4                1                   0                       0         1           1
4   8.05  7.93       0.021132  18974000             4                1                   1                       0         1           1
22  7.39  7.23       0.012308  11050700             1               -1                   0                       0         1           1
12  7.66  7.52       0.025902  23256600             1               -1                   0                       0         0           1
88  8.56  8.14       0.031764  27764200             3                1                   0                       0         0           1
73  6.95  6.18       0.021233  18559600             0                1                   0                       0         0           1

使用 accuracy_score 评估模型性能

# 评估模型性能。
Accuracy = accuracy_score(Y_Test, Y_Pred)
print("Accuracy:", Accuracy)
print("\n")

# 输出分类报告。
print("Classification Report:")
print(classification_report(Y_Test, Y_Pred))

输出:

Accuracy: 0.6

Classification Report:
              precision    recall  f1-score   support

           0       0.67      0.22      0.33         9
           1       0.59      0.91      0.71        11

    accuracy                           0.60        20
   macro avg       0.63      0.57      0.52        20
weighted avg       0.62      0.60      0.54        20

完整代码

#!/usr/bin/python3
# Create By GF 2024-01-07

# 在这个示例中, 我们使用 GradientBoostingClassifier 构建梯度提升树模型。
# 为了处理字符串类型的特征列, 我们使用了 pd.get_dummies 进行独热编码。
# 然后, 我们对特征进行标准化, 并使用 train_test_split 将数据集划分为训练集和测试集。
# 最后, 我们训练模型、进行预测, 并评估模型性能。
# 请注意, 这只是一个基本的示例, 实际应用中你可能需要更多的特征工程、调参和模型评估。

import datetime
# --------------------------------------------------
import pandas as pd
# --------------------------------------------------
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler

# 编写 "判断股票短期均线和长期均线关系" 函数。
def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int:

    if (Short_MA >= Long_MA): return  1
    if (Short_MA == Long_MA): return  0
    if (Short_MA <= Long_MA): return -1

    # ==============================================
    # End of Function.

if __name__ == "__main__":

    PDF = pd.read_csv("D:\\HBYH_000422_20150806_20151231.csv", header=0, sep=",")

    print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")
    print(PDF)

    # 转换 Pandas 中 DateFrame 数据类型。
    PDF["Date"] =          PDF["Date"].astype("datetime64[ns]")
    PDF["Open"] =          PDF["Open"].astype("float64")
    PDF["High"] =          PDF["High"].astype("float64")
    PDF["Low"] =           PDF["Low"].astype("float64")
    PDF["Close"] =         PDF["Close"].astype("float64")
    PDF["Pre_Close"] =     PDF["Pre_Close"].astype("float64")
    PDF["Change"] =        PDF["Change"].astype("float64")
    PDF["Turnover_Rate"] = PDF["Turnover_Rate"].astype("float64")
    PDF["Volume"] =        PDF["Volume"].astype("int64")
    PDF["MA5"] =           PDF["MA5"].astype("float64")
    PDF["MA10"] =          PDF["MA10"].astype("float64")

    # 输出 Pandas 中 DataFrame 字段和数据类型。
    print("[Message] Changed Pandas DataFrame Data Type:")
    print(PDF.dtypes)

    # 计算数据: 提取星期的索引, 从 0 到 6 (0 代表周一, 6 代表周日)。
    PDF["Weekday(Idx)"] =    PDF["Date"].apply(lambda X: X.weekday())
    # ..................................................
    # 计算数据: 计算节日 (节日对 A 股的影响, 中国节日 A 股休市, 所以只能探索国外节日对 A 股的影响, 这里仅用 "圣诞节(Christmas)" 和 "平安夜(Christmas Eve)" 做示例)。
    PDF["Festival"] = None
    for Idx in PDF.index:
        if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,24): PDF.loc[Idx, "Festival"] = "Christmas_Eve" # -> 平安夜。
        if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,25): PDF.loc[Idx, "Festival"] = "Christmas"     # -> 圣诞节。
    # ..................................................
    # 计算数据: 判断股票涨跌。
    PDF["Rise_Fall"] =       PDF["Change"].apply(lambda X: int(1) if X >= 0 else int(0))
    # ..................................................
    # 计算数据: 调用函数, 判断股票短期均线和长期均线关系。
    PDF["MA_Relationship"] = PDF.apply(lambda X: MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"]), axis=1)

    # 输出计算好的 DataFrame 数据框。
    print("[Message] Calculated DataFrame:")
    print(PDF)

    # 将字符串类型的特征列转换为数值 (独热编码)。
    PDF = pd.get_dummies(PDF, columns=["Festival"], drop_first=False)

    # 输出转换后的 DataFrame 数据框。
    print("[Message] DataFrame After One-Hot Encoding:")
    print(PDF)

    # 提取 标签(Label) 列。
    Y = PDF["Rise_Fall"]

    # 提取 特征(Feature) 列。
    X = PDF.drop(["Date", "Code", "Open", "Close", "Pre_Close", "Change", "MA5", "MA10", "Rise_Fall"], axis=1)

    # 数据集划分训练集和测试集(train_test_split)。
    X_Train, X_Test, Y_Train, Y_Test = train_test_split(X, Y, test_size=0.2, random_state=42)

    # 特征标准化(StandardScaler)。
    Obj_Scaler = StandardScaler()
    X_Train_Scaled = Obj_Scaler.fit_transform(X_Train)
    X_Test_Scaled = Obj_Scaler.transform(X_Test)

    # 创建 梯度提升树分类器(GradientBoostingClassifier)。
    GBC = GradientBoostingClassifier(n_estimators=100, random_state=42)

    # 训练 梯度提升树分类器(GradientBoostingClassifier) 模型。
    GBC.fit(X_Train_Scaled, Y_Train)

    # Value of Return:
    # +--------------------------------------------+
    # |▼        GradientBoostingClassifier         |
    # +--------------------------------------------+
    # | GradientBoostingClassifier(random_state=42)|
    # +--------------------------------------------+

    # 在测试集上进行预测。
    Y_Pred = GBC.predict(X_Test_Scaled)

    # 合并预测结果。
    Result = X_Test.copy()
    Result["Actually"] = Y_Test
    Result["Prediction"] = Y_Pred

    print("[Message] Prediction Results on The Test Data Set for GradientBoostingClassifier:")
    print(Result)

    # 评估模型性能。
    Accuracy = accuracy_score(Y_Test, Y_Pred)
    print("Accuracy:", Accuracy)
    print("\n")

    # 输出分类报告。
    print("Classification Report:")
    print(classification_report(Y_Test, Y_Pred))

其它

在这个示例中, 我们使用 GradientBoostingClassifier 构建梯度提升树模型。

为了处理字符串类型的特征列, 我们使用了 pd.get_dummies 进行独热编码。

然后, 我们对特征进行标准化, 并使用 train_test_split 将数据集划分为训练集和测试集。

最后, 我们训练模型、进行预测, 并评估模型性能。

请注意, 这只是一个基本的示例, 实际应用中你可能需要更多的特征工程、调参和模型评估。

总结

以上就是关于 金融数据 Scikit-Learn梯度提升树(GradientBoostingClassifier)实例 的全部内容。

更多内容可以访问我的代码仓库:

https://gitee.com/goufeng928/public

https://github.com/goufeng928/public

举报

相关推荐

0 条评论