0
点赞
收藏
分享

微信扫一扫

weka knn算法改进与实现


本文在weka下,主要使用高斯函数加权,选取最优K值进行优化。你也可以参考网上文档,将下文的 `KNN_lsh.java` 复制到某一目录并进行相关设置,进而在weka GUI中测试改进效果。

文件目录:

weka knn算法改进与实现_java

/weka_test/src/cug/lsh/KNN_lsh.java 如下:

package cug.lsh;

import weka.classifiers.*;
import weka.core.*;
import java.util.*;

/**
 * k-nearest-neighbour classifier with Gaussian-kernel weighting of the
 * neighbours' votes.
 *
 * Distance is a Hamming-style count of differing (nominal) attribute values;
 * alternative metrics are sketched in comments inside {@link #distance}.
 * Each neighbour contributes a weight of 1 + exp(-d^2 / (2*c^2)) with c = 0.3
 * to its class, and the resulting distribution is normalized.
 */
@SuppressWarnings("serial")
public class KNN_lsh extends Classifier {

    /** Copy of the training data captured in {@link #buildClassifier}. */
    private Instances m_Train;
    /** Number of neighbours k; defaults to 1 so the classifier works even if the setter is never called. */
    private int m_kNN = 1;

    /** Sets the number of neighbours k consulted at prediction time. */
    public void setM_kNN(int m_kNN) {
        this.m_kNN = m_kNN;
    }

    /**
     * "Trains" the classifier by storing a copy of the data.
     * kNN is lazy: all real work happens in {@link #distributionForInstance}.
     */
    public void buildClassifier(Instances data) throws Exception {
        m_Train = new Instances(data);
    }

    /**
     * Returns the normalized class-probability distribution for the given
     * instance, computed from its k nearest training neighbours.
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        Instances neighbors = findNeighbors(instance, m_kNN);
        return computeDistribution(neighbors, instance);
    }

    /**
     * Returns the (up to) kNN training instances closest to {@code instance}.
     *
     * BUG FIX: the previous incremental insertion dropped any instance whose
     * distance exceeded every entry currently in the list once i >= kNN-1,
     * even when fewer than kNN neighbours had been collected, so the result
     * could contain fewer than kNN instances (e.g. kNN=3, distances 1,2,5
     * returned only two neighbours). We now score every training instance,
     * sort by distance, and take the first min(kNN, n).
     */
    private Instances findNeighbors(Instance instance, int kNN) {
        List<HasDisInstances> scored = new ArrayList<HasDisInstances>(m_Train.numInstances());
        for (int i = 0; i < m_Train.numInstances(); i++) {
            Instance trainInstance = m_Train.instance(i);
            scored.add(new HasDisInstances(distance(instance, trainInstance), trainInstance));
        }
        // Ascending by distance; Double.compare avoids subtraction pitfalls.
        Collections.sort(scored, new Comparator<HasDisInstances>() {
            public int compare(HasDisInstances a, HasDisInstances b) {
                return Double.compare(a.distance, b.distance);
            }
        });

        int min = Math.min(kNN, scored.size());
        Instances instances = new Instances(m_Train, min);
        for (int i = 0; i < min; i++) {
            instances.add(scored.get(i).instance);
        }
        return instances;
    }

    /**
     * Hamming-style distance: counts attributes (class attribute excluded)
     * whose integer-cast values differ. Assumes nominal attribute values
     * (int cast discards any fractional part) -- TODO confirm for numeric data.
     */
    private double distance(Instance first, Instance second) {
        double distance = 0;
        for (int i = 0; i < m_Train.numAttributes(); i++) {
            if (i == m_Train.classIndex())
                continue;
            if ((int) first.value(i) != (int) second.value(i)) {
                distance += 1;
            }
            // Alternative distance formulas (change here to experiment):
            // distance += (second.value(i) - first.value(i)) * (second.value(i) - first.value(i)); // Euclidean (squared)
            // distance += second.value(i) * Math.log(second.value(i) / first.value(i)); // maximum entropy
            // distance += Math.pow((second.value(i) - first.value(i)), 2) / first.value(i); // chi-square distance
        }
        // distance = Math.sqrt(distance); // enable together with the Euclidean variant
        return distance;
    }

    /**
     * Accumulates a Gaussian-weighted vote 1 + exp(-d^2 / 0.18) per neighbour
     * into its class bucket (0.18 = 2*c^2 with c = 0.3), then normalizes.
     */
    private double[] computeDistribution(Instances data, Instance instance) throws Exception {
        double[] prob = new double[data.numClasses()];
        for (int i = 0; i < data.numInstances(); i++) {
            int classVal = (int) data.instance(i).classValue();
            double x = distance(instance, data.instance(i));
            prob[classVal] += 1 + Math.exp(-x * x / 0.18); // c = 0.3
        }
        Utils.normalize(prob);
        return prob;
    }

    /** Pairs a training instance with its distance to the query instance. */
    private static final class HasDisInstances {
        final double distance;
        final Instance instance;

        HasDisInstances(double distance, Instance instance) {
            this.distance = distance;
            this.instance = instance;
        }
    }
}

/weka_test/src/cug/lsh/KNN_lsh_use.java(主函数) 如下:

package cug.lsh;

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class KNN_lsh_use {
    /**
     * Loads the credit-g data set, splits off the first 20% of instances as
     * a test set, searches k in [3, min(sqrt(n)+3, 20)) for the highest
     * accuracy, and prints "accuracy,bestK".
     *
     * BUG FIX: the original split did test.add(train.instance(i)) followed by
     * train.delete(i) while i advanced. Deleting shifts all later indices
     * down, so only every second instance (originals 0, 2, 4, ...) was moved
     * and the split was not the intended first 20%. Removing index 0 each
     * iteration moves exactly the first `size` instances.
     */
    public static void main(String[] args) throws Exception {
        Instances train = DataSource.read("E:/DataLearing/data/credit-g.arff"); // TODO: take the path from args
        train.setClassIndex(train.numAttributes() - 1);

        int size = (int) (train.numInstances() * 0.2); // build the test set (first 20%)
        Instances test = new Instances(train, size);
        test.setClassIndex(test.numAttributes() - 1);
        for (int i = 0; i < size; i++) {
            test.add(train.instance(0));
            train.delete(0); // always remove the head so indices never shift under us
        }

        KNN_lsh classifier = new KNN_lsh();
        // search for the best k value
        int optiK = 0;
        int bestCount = 0; // highest number of correct classifications seen so far
        for (int m_kNN = 3; m_kNN < Math.sqrt(train.numInstances()) + 3 && m_kNN <= 20; m_kNN++) {
            classifier.setM_kNN(m_kNN);
            classifier.buildClassifier(train);

            int count = 0;
            for (int i = 0; i < test.numInstances(); i++) {
                if (classifier.classifyInstance(test.instance(i)) == test.instance(i).classValue())
                    count++;
            }
            if (count > bestCount) {
                optiK = m_kNN;
                bestCount = count;
            }
        }

        System.out.println(1.0 * bestCount / test.numInstances() + "," + optiK);
    }
}


举报

相关推荐

0 条评论