0
点赞
收藏
分享

微信扫一扫

java 训练XGBoost模型并进行预测

eelq 2022-01-24 阅读 90

概述

XGBoost官方文档:​​https://xgboost.readthedocs.io/en/stable/​​

XGBoost支持多种语言:python、R、C/C++、Java、Ruby等。

平常学习时,通常使用python或R语言学习XGBoost。但生产环境一般使用Python训练模型,C++、Java读取模型进行线上预测,或直接线上训练模型并使用。

本文使用Java实时训练XGBoost模型,然后使用该模型预测数据。

一些XGBoost知识

  1. Java中使用XGBoost肯定是要加载其提供的库,本文使用maven包管理工具加载,pom内容如下:
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_2.12</artifactId>
<version>1.5.1</version>
</dependency>
  1. XGBoost使用时,需配合其自提供数据结构进行模型训练,数据预测等功能,该数据结构称为DMatrix。按官网解释,其可接受libsvm格式数据文件、系数矩阵、稠密矩阵等作为数据源。官网案例使用的libsvm格式数据文件作为数据输入,本文使用稠密矩阵格式数据。
  2. Java训练XGBoost模型时,需提供:
  • DMatrix格式的输入;
  • 训练轮数(round);
  • 模型参数(param): 如

(1)eta--可理解为防止过拟合的学习率,缺省为0.3;

(2)max_depth--树的深度,缺省为6;

(3)silent--训练时是否打印信息,0打印,1不打印,缺省为0;

(4)booster--梯度提升模型,有gblinear线性模型, gbtree树模型和dart带dropout的树模型(可能理解有偏差);缺省为gbtree;

(5)objective--目标函数,有reg:linear--线性回归;reg:logistic--逻辑回归;multi:softmax--多分类;multi:softprob--多分类等;缺省为reg:linear;

(6)eval_metric--评估指标,不同的目标函数缺省对应不同的评估指标,有rmse、logloss、mlogloss、error、merror、nuc、ndcg等;

(7)seed--随机数种子,缺省为0;

  • watches:训练模型时,评估模型在数据集上的表现,可添加训练集、测试集等(据官网如此说,可能理解有偏差);

代码

package com.gzw.java_xgboost_train_predict;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

import java.util.HashMap;

@SpringBootApplication
public class JavaXgboostTrainPredictApplication
{

public static void main(String[] args) throws XGBoostError
{
SpringApplication.run(JavaXgboostTrainPredictApplication.class, args);

//1. 准备数据
float[] xData = new float[]{ 1, 2, 3, 4, 5, 6, 7, 8, 9 };
float[] yData = new float[ xData.length ];
int index = 0;
for( float x : xData )
{
yData[index++] = (float)( x * 2 + 5 + Math.random()/5.0 ); //ps: random() 获取[0, 1)的数
}


//2. 训练模型
//2.1 需转换为XGBoost专用数据类型: DMatrix
int nRow = xData.length;
int nCol = 1;
float missing = 0.0f;
DMatrix trainMatrix = new DMatrix( xData, nRow, nCol, missing ); //定义输入数据
trainMatrix.setLabel( yData ); //设置标签

//2.2 准备生成模型需要的参数
HashMap<String, Object> mapParams = new HashMap<>(); //XGBoost模型的参数,优化器、树深度、超参数等
mapParams.put("eta", 0.3); //收缩步长, 可理解为XGBoost中的学习率
mapParams.put("booster", "gblinear"); //使用线性模型
mapParams.put("max_depth", 5); //树深度
mapParams.put("silent", 0); //0: 打印训练信息, 1: 不打印

HashMap<String, DMatrix> mapWatches = new HashMap<>(); //XGBoost模型训练时,可添加想要观察的模型在某数据集中的表现,如训练集、测试集等
mapWatches.put("selTrain", trainMatrix ); //观察训练集的performance(本例不提供测试集)

//2.3 训练模型
int round = 500; //训练轮数
Booster booster = XGBoost.train( trainMatrix, mapParams, round, mapWatches, null, null ); //不自定义参数, 故皆设置为null
System.out.println( booster.getAttrs() ); //打印训练的结果: rmse, 最佳训练轮数

//3. 使用训练的模型做预测
float[] xTestData = new float[]{ 10, 20, 30 };
DMatrix testMatrix = new DMatrix( xTestData, xTestData.length, 1, 0.0f );
float[][] predictResult = booster.predict( testMatrix );

System.out.println( "input=[10,20,30]时的output为:" );
for( float[] y_hat : predictResult )
{
System.out.println(y_hat[0]);
}
}

}

结果

java 训练XGBoost模型并进行预测_数据

可以看到,输入为[10,20,30]时,使用训练的模型预测结果为[25.135668, 45.13688, 65.138084],符合手动制作的训练数据的线性关系:y=2x + 5 + 噪声,表明训练成功。

备注

因训练数据与标签关系为手动指定的简单的线性关系,且只使用的一个特征,故代码训练时使用的梯度提升模型为gblinear;

实际使用中,还需分析特征数据与标签间的关系,选择合适的梯度提升模型。另外,在不熟悉XGBoost的情况下,尽量使用其提供的缺省参数。

举报

相关推荐

0 条评论