1、三维空间点代码:
1、kdtree.h
/**
* @file kdtree.h
* @brief This is a brief description.
* @author dongjian
* @par Copyright (c):
* All Rights Reserved
* @date 2018:04:24
* @note matters needing attention
*/
#ifndef _5383BD42_370E_4C00_A25E_AD4403E5656A
#define _5383BD42_370E_4C00_A25E_AD4403E5656A
#include <vector>
#include <queue>
#include <iostream>
namespace kdtree
{
/// A point in 3D space.
class Point
{
public:
    double x;
    double y;
    double z;
    /// Euclidean distance to @p p. The trailing `const` means this point is not modified.
    double Distance(const Point& p) const;
    /// Constructs the point (ox, oy, oz).
    Point(double ox, double oy, double oz);
    /// Prints the point as "最近点:(x,y,z)" ("nearest point") followed by a newline.
    void printPoint() const {
        std::cout << "最近点:("<< this->x << "," << this->y << "," << this->z << ")"<<std::endl;
    }
};

/// Axis-aligned box [xmin,xmax] x [ymin,ymax] x [zmin,zmax]; bounds are inclusive.
class BBox
{
public:
    double xmin, ymin, zmin;
    double xmax, ymax, zmax;
    /// True if @p p lies inside the box or on its boundary.
    bool Contains(const Point& p) const;
    /// Tests whether @p bbox lies entirely inside this box.
    bool Contains(const BBox& bbox) const;
    /// Tests whether this box and @p bbox overlap (touching counts as overlapping).
    bool Intersects(const BBox& bbox) const;
    /// Note the argument order: x range, then y range, then z range.
    BBox(double x0, double x1, double y0, double y1, double z0, double z1);
    /// Box spanning [-DBL_MAX, DBL_MAX] on every axis; used as the extent of the root node.
    static BBox UniverseBBox();
    /// Euclidean distance from @p p to the nearest point of the box (0 when p is inside).
    double ShortestDistance(const Point& p) const;
};

/// Split axis of a tree node.
enum Dimension
{
    X = 0,
    Y = 1,
    Z = 2
};

/// A tree node covering positions [begin, end) of the owning tree's index permutation.
class Node
{
public:
    Node* left;          // left child (coordinate <= pivot along `dimension`)
    Node* right;         // right child (coordinate >= pivot along `dimension`)
    int begin;           // first position, inclusive
    int end;             // one past the last position, exclusive
    Dimension dimension; // split axis
    double pivot;        // split value
    Node(int b, int e, Dimension dim, double piv);
    bool IsLeaf() const;
    /// Region of the left subtree: @p current_entext clipped to coordinate <= pivot on the split axis.
    BBox LeftSubTreeEnvelope(const BBox& current_entext) const;
    /// Region of the right subtree: @p current_entext clipped to coordinate >= pivot on the split axis.
    BBox RightSubTreeEnvelope(const BBox& current_entext) const;
};

/// Static 3D KD-tree over a caller-owned vector of points.
class KDTree
{
public:
    KDTree() = default;
    /// Frees every node built by SetInput.
    ~KDTree() { FreeTree(pRoot); }
    // The tree owns raw node pointers; a shallow copy would free them twice.
    KDTree(const KDTree&) = delete;
    KDTree& operator=(const KDTree&) = delete;

    /// Builds the tree over @p inputs. Only the address of @p inputs is stored, so the
    /// vector must outlive the tree and must not be modified while the tree is in use.
    void SetInput(const std::vector<Point>& inputs);
    /// Appends to @p indices the index of every input point inside @p bbox.
    void RangeSearch(const BBox& bbox, std::vector<int>& indices) const ;
    /// Returns the index of the input point closest to @p searchpoint.
    int NearestNeighbourSearch(const Point& searchpoint) const ;
    /// Replaces @p indices with the indices of the (up to) @p k points closest to
    /// @p searchpoint, ordered from farthest to nearest.
    void NearestNeighbourSearch(const Point& searchpoint, int k, std::vector<int>& indices) const;
    /// Maximum number of points stored in a leaf; keep it >= 1.
    static int _leaf_max_size;
private:
    const std::vector<Point>* pData = nullptr; // non-owning pointer to the input points
    std::vector<int> indices;                  // permutation of point indices, reordered while building
    Node* pRoot = nullptr;                      // root node; owned by this tree
private:
    /// Orders (index, distance) pairs by distance, so the priority queue keeps the largest on top.
    struct comparepair
    {
        bool operator()(const std::pair<int, double>&, const std::pair<int, double>&);
    };
    typedef std::priority_queue<std::pair<int, double>, std::vector<std::pair<int, double>>, comparepair> index_distance;
    /// Recursively builds the subtree for positions [start, end).
    Node* DivideTree(const std::vector<Point>& inputs,int start,int end);
    /// Chooses the split axis (largest coordinate extent) for a node's range.
    Dimension SelectDimension(const std::vector<Point>& inputs, int start, int end);
    /// Moves the median element along @p dim into the middle of [start, end) and returns its position.
    int Partition(int start, int end, Dimension dim);
    /// Coordinate @p dim of the point at permutation position @p index.
    double GetAt(int index, Dimension dim);
    void _range_search(const BBox& search_bbox, Node* pNode, const BBox& current_extent,std::vector<int>& indices) const ;
    int _nearest_neighbour(const Point& searchPoint, Node* pNode, const BBox& current_extent, double& current_shorest_dis) const ;
    void _nearest_neighbour(const Point& searchPoint, int k, index_distance& ins, Node* pNode, const BBox& current_extent, double& current_shorest_dis) const ;
    /// Deletes @p pNode and all of its descendants (nullptr is a no-op).
    static void FreeTree(Node* pNode)
    {
        if (pNode == nullptr)
            return;
        FreeTree(pNode->left);
        FreeTree(pNode->right);
        delete pNode;
    }
};
}
#endif //_5383BD42_370E_4C00_A25E_AD4403E5656A
2、kdtree.cpp 含demo实现
#include "kdtree.h"
#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
using namespace std;
using namespace kdtree;
Point::Point(double ox, double oy, double oz) :
x(ox),
y(oy),
z(oz)
{}
double Point::Distance(const Point& p) const
{
double dx = p.x - x;
double dy = p.y - y;
double dz = p.z - z;
return sqrt(dx*dx + dy*dy + dz*dz);
}
// Region covered by the left child: the parent's region with its upper bound on the
// split axis lowered to the pivot.
BBox Node::LeftSubTreeEnvelope(const BBox& current_entext) const
{
    BBox clipped = current_entext;
    if (dimension == X)
        clipped.xmax = pivot;
    else if (dimension == Y)
        clipped.ymax = pivot;
    else if (dimension == Z)
        clipped.zmax = pivot;
    return clipped;
}
// Region covered by the right child: the parent's region with its lower bound on the
// split axis raised to the pivot.
BBox Node::RightSubTreeEnvelope(const BBox& current_entext) const
{
    BBox clipped = current_entext;
    if (dimension == X)
        clipped.xmin = pivot;
    else if (dimension == Y)
        clipped.ymin = pivot;
    else if (dimension == Z)
        clipped.zmin = pivot;
    return clipped;
}
// True if `bbox` lies entirely inside this box (shared boundaries count as inside).
// Each axis needs bbox's lower bound >= ours and bbox's upper bound <= ours.
bool BBox::Contains(const BBox& bbox) const
{
    if (bbox.xmin < xmin) return false;
    if (bbox.xmax > xmax) return false;
    if (bbox.ymin < ymin) return false;
    if (bbox.ymax > ymax) return false;
    if (bbox.zmin < zmin) return false;
    if (bbox.zmax > zmax) return false;
    return true;
}
// True if p lies inside the box or on its boundary.
bool BBox::Contains(const Point& p) const
{
    const bool insideX = xmin <= p.x && p.x <= xmax;
    const bool insideY = ymin <= p.y && p.y <= ymax;
    const bool insideZ = zmin <= p.z && p.z <= zmax;
    return insideX && insideY && insideZ;
}
bool BBox::Intersects(const BBox& bbox) const
{
if (bbox.xmin > xmax || bbox.xmax < xmin) return false;
if (bbox.ymin > xmax || bbox.ymax < ymin) return false;
if (bbox.zmin > zmax || bbox.zmax < zmin) return false;
return true;
}
// Box spanning [-DBL_MAX, DBL_MAX] on all three axes; the starting extent for tree searches.
BBox BBox::UniverseBBox()
{
    const double big = std::numeric_limits<double>::max();
    return BBox(-big, big,   // x range
                -big, big,   // y range
                -big, big);  // z range
}
BBox::BBox(double x0, double x1, double y0, double y1, double z0, double z1) :
xmin(x0),
xmax(x1),
ymin(y0),
ymax(y1),
zmin(z0),
zmax(z1)
{}
// Euclidean distance from p to the closest point of the box; 0 when p is inside.
double BBox::ShortestDistance(const Point& p) const
{
    // Gap between coordinate v and the interval [lo, hi] along one axis.
    auto gap = [](double v, double lo, double hi) {
        if (lo - v > 0)
            return lo - v;
        if (v - hi > 0)
            return v - hi;
        return 0.0;
    };
    const double gx = gap(p.x, xmin, xmax);
    const double gy = gap(p.y, ymin, ymax);
    const double gz = gap(p.z, zmin, zmax);
    return sqrt(gx * gx + gy * gy + gz * gz);
}
// Creates a node with no children covering positions [b, e), split on `dim` at `piv`.
// The initializer list follows the member declaration order (left, right, begin, ...).
Node::Node(int b, int e, Dimension dim, double piv)
    : left(nullptr), right(nullptr),
      begin(b), end(e),
      dimension(dim), pivot(piv)
{
}
// A node is a leaf when it has no children at all.
bool Node::IsLeaf() const
{
    return !(left || right);
}
/**
 * @brief Chooses the split axis for the node covering permutation positions [start, end).
 *
 * @param inputs the input points (the vector SetInput stored in pData)
 * @param start  first position in `indices`, inclusive
 * @param end    last position in `indices`, exclusive
 * @return Dimension the axis (X, Y or Z) along which the node's points have the largest
 *         extent (max - min); ties go to the earlier axis. X for an empty range.
 */
Dimension KDTree::SelectDimension(const std::vector<Point>& inputs, int start, int end)
{
    if (start >= end)
        return X;
    // The node's points are inputs[indices[i]] for i in [start, end). Once Partition has
    // reordered `indices`, that set differs from inputs[start..end), so go through indices.
    const Point& first = inputs[indices[start]];
    double lo[3] = {first.x, first.y, first.z};
    double hi[3] = {first.x, first.y, first.z};
    for (int i = start + 1; i < end; ++i)
    {
        const Point& p = inputs[indices[i]];
        const double coord[3] = {p.x, p.y, p.z};
        for (int axis = 0; axis < 3; ++axis)
        {
            lo[axis] = std::min(lo[axis], coord[axis]);
            hi[axis] = std::max(hi[axis], coord[axis]);
        }
    }
    // Strict '>' keeps the first axis on ties.
    int best = 0;
    for (int axis = 1; axis < 3; ++axis)
    {
        if (hi[axis] - lo[axis] > hi[best] - lo[best])
            best = axis;
    }
    return static_cast<Dimension>(best);
}
double KDTree::GetAt(int index, Dimension dim)
{
auto p = (*pData)[indices[index]];
return dim == X ? p.x : (dim == Y ? p.y : p.z);
}
// Rearranges indices[start, end) so that the element at the middle position is the one
// that would be there if the range were sorted by coordinate `dimension`: everything
// before it is <= and everything after it is >= along that axis. Returns the middle
// position, or -1 (after printing an error) for an empty range.
int KDTree::Partition(int start, int end, Dimension dimension)
{
    const int count = end - start;
    if (count <= 0)
    {
        cout << "a serious error occurs " << start << "\t" << endl;
        return -1;
    }
    const std::vector<Point>* points = pData;
    auto lessAlongAxis = [points, dimension](int i, int j) {
        switch (dimension)
        {
        case X:
            return (*points)[i].x < (*points)[j].x;
        case Y:
            return (*points)[i].y < (*points)[j].y;
        case Z:
            return (*points)[i].z < (*points)[j].z;
        }
        return true;
    };
    const int median = start + count / 2;
    std::nth_element(indices.begin() + start, indices.begin() + median,
                     indices.begin() + end, lessAlongAxis);
    return median;
}
// Returns the index of the input point closest to `searchpoint`, or -1 when the tree
// holds no points (SetInput received an empty vector, so the root is null).
int KDTree::NearestNeighbourSearch(const Point& searchpoint) const
{
    if (pRoot == nullptr)
        return -1;
    double shortestDistance = std::numeric_limits<double>::max();
    return _nearest_neighbour(searchpoint, pRoot, BBox::UniverseBBox(), shortestDistance);
}
// Replaces `indices` with the indices of the (up to) k input points closest to
// `searchpoint`. Results come out farthest-first, because the candidate heap keeps the
// largest distance on top. k <= 0 or an empty tree yields an empty result; previously
// k == 0 made the helper pop an empty heap.
void KDTree::NearestNeighbourSearch(const Point& searchpoint, int k, std::vector<int>& indices) const
{
    indices.clear();
    if (k <= 0 || pRoot == nullptr)
        return;
    double shortestDistance = std::numeric_limits<double>::max();
    index_distance neighbours;
    _nearest_neighbour(searchpoint, k, neighbours, pRoot, BBox::UniverseBBox(), shortestDistance);
    indices.reserve(neighbours.size());
    while (!neighbours.empty())
    {
        indices.push_back(neighbours.top().first);
        neighbours.pop();
    }
}
// "Less by distance": with std::priority_queue this puts the farthest candidate on top.
bool KDTree::comparepair::operator()(const std::pair<int, double>& p0, const std::pair<int, double>& p1)
{
    return p1.second > p0.second;
}
/**
 * Recursive helper for the k-nearest-neighbour search.
 *
 * @param searchPoint          query point
 * @param k                    number of neighbours wanted (nothing is collected when k <= 0)
 * @param ins                  heap of (index, distance) candidates; top() is the farthest kept
 * @param pNode                subtree to search (nullptr is ignored)
 * @param current_extent       region covered by pNode
 * @param current_shorest_dis  pruning radius: the k-th best distance once k candidates
 *                             are held, otherwise the caller's initial value
 */
void KDTree::_nearest_neighbour(const Point& searchPoint, int k, index_distance& ins,
    Node* pNode, const BBox& current_extent, double& current_shorest_dis) const
{
    if (pNode == nullptr || k <= 0)
        return;
    // Prune: no point in this region can beat the current k-th best.
    if (current_extent.ShortestDistance(searchPoint) >= current_shorest_dis)
        return;
    if (pNode->IsLeaf())
    {
        for (int i = pNode->begin; i < pNode->end; ++i)
        {
            double distance = (*pData)[indices[i]].Distance(searchPoint);
            if (static_cast<int>(ins.size()) < k)
            {
                ins.push(std::pair<int, double>(indices[i], distance));
                // Once the heap is full its farthest member defines the search radius.
                if (static_cast<int>(ins.size()) == k)
                    current_shorest_dis = ins.top().second;
            }
            else if (distance < current_shorest_dis)
            {
                ins.pop();   // evict the current farthest candidate
                ins.push(std::pair<int, double>(indices[i], distance));
                current_shorest_dis = ins.top().second;
            }
        }
        return;
    }
    BBox leftRegion = pNode->LeftSubTreeEnvelope(current_extent);
    BBox rightRegion = pNode->RightSubTreeEnvelope(current_extent);
    // Visit the nearer child first so the radius shrinks before the farther one is tested.
    if (leftRegion.ShortestDistance(searchPoint) < rightRegion.ShortestDistance(searchPoint))
    {
        _nearest_neighbour(searchPoint, k, ins, pNode->left, leftRegion, current_shorest_dis);
        _nearest_neighbour(searchPoint, k, ins, pNode->right, rightRegion, current_shorest_dis);
    }
    else
    {
        _nearest_neighbour(searchPoint, k, ins, pNode->right, rightRegion, current_shorest_dis);
        _nearest_neighbour(searchPoint, k, ins, pNode->left, leftRegion, current_shorest_dis);
    }
}
/**
 * Recursive helper for the single-nearest-neighbour search.
 *
 * @param searchPoint          query point
 * @param pNode                subtree to search (nullptr yields -1)
 * @param current_extent       region covered by pNode
 * @param current_shorest_dis  best distance found so far; lowered when a closer point is found
 * @return index of a point in this subtree strictly closer than the incoming
 *         current_shorest_dis, or -1 if there is none
 */
int KDTree::_nearest_neighbour(const Point& searchPoint, Node* pNode, const BBox& current_extent, double& current_shorest_dis) const
{
    if (pNode == nullptr)
        return -1;
    // Prune: nothing in this region can be closer than the best found so far.
    if (current_extent.ShortestDistance(searchPoint) >= current_shorest_dis)
        return -1;
    if (pNode->IsLeaf())
    {
        int shortestindex = -1;
        for (int i = pNode->begin; i < pNode->end; ++i)
        {
            double distance = (*pData)[indices[i]].Distance(searchPoint);
            if (distance < current_shorest_dis)
            {
                shortestindex = indices[i];
                current_shorest_dis = distance;
            }
        }
        return shortestindex;
    }
    BBox leftRegion = pNode->LeftSubTreeEnvelope(current_extent);
    BBox rightRegion = pNode->RightSubTreeEnvelope(current_extent);
    // Search the nearer child first. The second call only returns an index when it found
    // something strictly closer than the first child's best, so its answer takes precedence.
    if (leftRegion.ShortestDistance(searchPoint) < rightRegion.ShortestDistance(searchPoint))
    {
        int left = _nearest_neighbour(searchPoint, pNode->left, leftRegion, current_shorest_dis);
        int right = _nearest_neighbour(searchPoint, pNode->right, rightRegion, current_shorest_dis);
        return right == -1 ? left : right;
    }
    int right = _nearest_neighbour(searchPoint, pNode->right, rightRegion, current_shorest_dis);
    int left = _nearest_neighbour(searchPoint, pNode->left, leftRegion, current_shorest_dis);
    return left == -1 ? right : left;
}
// Maximum number of points a leaf may hold; DivideTree splits ranges larger than this.
int KDTree::_leaf_max_size = 2; // originally 15
// Builds the tree over `inputs`. Only the address of `inputs` is kept (pData), so the
// vector must outlive this tree. A tree built by an earlier call is not freed here.
void KDTree::SetInput(const std::vector<Point>& inputs)
{
    pData = &inputs;
    // Start from the identity permutation 0, 1, 2, ...; DivideTree reorders it in place.
    indices.resize(inputs.size());
    int next = 0;
    for (int& slot : indices)
        slot = next++;
    pRoot = DivideTree(inputs, 0, static_cast<int>(inputs.size()));
}
// Recursively builds the subtree for permutation positions [start, end).
// It picks the split axis, places the median element at the middle position via
// Partition, and keeps splitting while the range holds more points than a leaf allows.
// Returns nullptr for an empty range.
Node* KDTree::DivideTree(const std::vector<Point>& inputs, int start, int end)
{
    const int size = end - start;
    if (size <= 0)
        return nullptr;
    const Dimension dim = SelectDimension(inputs, start, end);
    const int median = Partition(start, end, dim);
    Node* pNode = new Node(start, end, dim, GetAt(median, dim));
    // A leaf must hold at least one point. With a limit below 1, a one-point range would
    // split into [start, start) and [start, end) and recurse forever.
    const int leafLimit = std::max(_leaf_max_size, 1);
    if (size > leafLimit)
    {
        pNode->left = DivideTree(inputs, start, median);
        pNode->right = DivideTree(inputs, median, end);
    }
    return pNode;
}
void KDTree::RangeSearch(const BBox& bbox, std::vector<int>& indices) const
{
BBox universe_bbox = BBox::UniverseBBox();
_range_search(bbox, pRoot, universe_bbox, indices);
}
void KDTree::_range_search(const BBox& search_bbox, Node* pNode, const BBox& current_extent, std::vector<int>& ins) const
{
if (nullptr == pNode)
return;
if (pNode->IsLeaf())
{
for (int i = pNode->begin; i < pNode->end; ++i)
{
const Point& p = (*pData)[indices[i]];
if (search_bbox.Contains(p))
ins.push_back(indices[i]);
}
}
else
{
//trim bounding box
BBox leftRegion = pNode->LeftSubTreeEnvelope(current_extent);
if (search_bbox.Contains(leftRegion))
{
for (int i = pNode->left->begin; i < pNode->left->end; ++i)
ins.push_back(indices[i]);
}
else if (search_bbox.Intersects(leftRegion))
{
_range_search(search_bbox, pNode->left, leftRegion, ins);
}
BBox rightRegion = pNode->RightSubTreeEnvelope(current_extent);
if (search_bbox.Contains(rightRegion))
{
for (int i = pNode->right->begin; i < pNode->right->end; ++i)
ins.push_back(indices[i]);
}
else if (search_bbox.Intersects(rightRegion))
{
_range_search(search_bbox, pNode->right, rightRegion, ins);
}
}
}
int main(int argc, char *argv[]){
std::vector<kdtree::Point> input;
// 1、输入数据
for (int i = 0; i < 10; i++)
{
double x,y,z;
std::cin>>x>>y>>z;
std::cout<<x<<y<<z<<std::endl;
kdtree::Point p(x, y, z);
input.push_back(p);
}
//2、构建树
KDTree kdtree;
kdtree.SetInput(input);
//3、最近邻搜索
//创建寻找的点
int k =3;
Point p1(1.2 , 2.1, 3.2);
std::vector<int> indices;//索引
kdtree.NearestNeighbourSearch(p1, k, indices);
//4、打印K个最近点
for(int i = 0; i<k; i++ ){
input[indices[i]].printPoint();
std::cout << "距离是:" << p1.Distance(input[indices[i]]) << std::endl;
}
return 0;
}
结果:
2、二维空间代码
1、kdtree_new.h
//
// kdtree_new.h
// Test
//
// Created by xiuzhu on 2021/7/13.
//
#ifndef kdtree_new_h
#define kdtree_new_h
#include <vector>
using namespace std;
//template <class T>
/// A 2D (in general k-D) tree node; every node is the root of its own subtree.
/// Nodes are linked through raw parent/child pointers, so copying is disabled.
class KDTree {
private:
    std::vector<double> root;   // the point stored at this node; empty for an empty subtree
    KDTree *parent;
    KDTree *left_child;
    KDTree *right_child;
public:
    KDTree();
    ~KDTree();
    KDTree(const KDTree&) = delete;
    KDTree& operator=(const KDTree&) = delete;
    bool is_empty();   // true if this node stores no point
    bool is_leaf();    // true if this node stores a point and has no children
    bool is_root();    // true if this node stores a point and has no parent
    bool is_left();    // true if this node is its parent's left child
    bool is_right();   // true if this node is its parent's right child
    std::vector<std::vector<double> > Transpose(const std::vector<std::vector<double> > &Matrix); // swap rows and columns
    double findMiddleValue(std::vector<double> vec); // element at index size()/2 after sorting
    void buildKdTree(KDTree* tree, std::vector<std::vector<double> > data, unsigned depth); // build the kd tree
    void printKdTree(KDTree* tree, unsigned int depth); // print the tree level by level
    double measureDistance(std::vector<double> point1, std::vector<double> point2, unsigned method); // distance between two points (0: Euclidean, 1: Manhattan)
    std::vector<double> searchNearestNeighbor(std::vector<double> goal, KDTree *tree); // nearest neighbour of goal in tree
};
#endif /* KDTree_h */
2、kdtree_new.cpp
//
// kdtree_new.cpp
// Test
//
// Created by xiuzhu on 2021/7/13.
//
#include "kdtree_new.h"
#include <stdio.h>
#include <iostream>
#include <vector>
#include <algorithm>
#include <string>
#include <cmath>
using namespace std;
// Creates an empty, unlinked node (no point, no parent, no children).
KDTree::KDTree()
    : parent(nullptr),
      left_child(nullptr),
      right_child(nullptr)
{
}
// Frees the whole subtree below this node. buildKdTree allocates the children with `new`
// and each child has exactly one parent, so deleting them here frees every node once.
KDTree::~KDTree()
{
    delete left_child;
    delete right_child;
}
// A node is empty when it stores no point.
bool KDTree::is_empty()
{
    return root.size() == 0;
}
// A leaf stores a point and has no children.
bool KDTree::is_leaf()
{
    if (root.empty())
        return false;
    return left_child == nullptr && right_child == nullptr;
}
// The root stores a point and has no parent.
bool KDTree::is_root()
{
    return parent == nullptr && !is_empty();
}
bool KDTree::is_left()
{
return parent->left_child->root == root;
}
bool KDTree::is_right()
{
return parent->right_child->root == root;
}
// Transposes a row-major matrix so that Trans[i][j] == Matrix[j][i] (one row per point
// becomes one row per coordinate axis). The column count is taken from the first row, so
// every row is assumed to be at least that long. An empty matrix yields an empty result
// instead of reading Matrix[0] out of bounds.
vector<vector<double> > KDTree::Transpose(const vector<vector<double>> &Matrix)
{
    if (Matrix.empty())
        return vector<vector<double> >();
    const size_t rows = Matrix.size();
    const size_t cols = Matrix[0].size();
    vector<vector<double> > Trans(cols, vector<double>(rows, 0));
    for (size_t i = 0; i < cols; ++i)
    {
        for (size_t j = 0; j < rows; ++j)
        {
            Trans[i][j] = Matrix[j][i];
        }
    }
    return Trans;
}
// Median of the values along one axis: the element that would sit at index size()/2 if
// `vec` were sorted (the upper median for even sizes). `vec` is taken by value, so the
// caller's data is left untouched. `vec` must not be empty.
double KDTree::findMiddleValue(vector<double> vec)
{
    const auto middle = vec.begin() + vec.size() / 2;
    std::nth_element(vec.begin(), middle, vec.end());
    return *middle;
}
// Recursively builds the subtree rooted at `tree` from the points in `data`.
// The split axis is depth % (dimensionality of the points). The first point whose
// coordinate on that axis equals the median becomes this node's point. Points below the
// median go to the left subtree and all others to the right. For 2+ points both children
// are always allocated, even when one side receives nothing (it then stays an empty node).
void KDTree::buildKdTree(KDTree *tree, vector<vector<double>> data, unsigned int depth)
{
    switch (data.size())
    {
    case 0:
        return;                 // nothing to store
    case 1:
        tree->root = data[0];   // a single point: this node becomes a leaf
        return;
    default:
        break;
    }

    const unsigned axis = depth % data[0].size();
    // Coordinates of every point along the split axis.
    const vector<double> column = Transpose(data)[axis];
    const double median = findMiddleValue(column);

    vector<vector<double> > below;      // coordinate < median
    vector<vector<double> > atOrAbove;  // coordinate >= median (except the node's own point)
    for (size_t row = 0; row < data.size(); ++row)
    {
        const bool becomesNodePoint = column[row] == median && tree->root.empty();
        if (becomesNodePoint)
            tree->root = data[row];
        else if (column[row] < median)
            below.push_back(data[row]);
        else
            atOrAbove.push_back(data[row]);
    }

    KDTree* leftNode = new KDTree;
    leftNode->parent = tree;
    tree->left_child = leftNode;
    KDTree* rightNode = new KDTree;
    rightNode->parent = tree;
    tree->right_child = rightNode;

    buildKdTree(leftNode, below, depth + 1);
    buildKdTree(rightNode, atOrAbove, depth + 1);
}
// Prints the subtree rooted at `tree`. Each node's coordinates are printed comma-terminated
// and indented by `depth` tabs. Non-null children follow one level deeper, after a
// " left:" / "right:" label. For a node with any child, a newline is written after the
// left slot and after the right slot, even when that child is null.
void KDTree::printKdTree(KDTree *tree, unsigned int depth)
{
    const string indent(depth, '\t');
    cout << indent;
    for (double coord : tree->root)
        cout << coord << ",";
    cout << endl;

    const bool hasLeft = tree->left_child != nullptr;
    const bool hasRight = tree->right_child != nullptr;
    if (!hasLeft && !hasRight)
        return;   // leaf: nothing more to print

    const string childIndent(depth + 1, '\t');
    if (hasLeft)
    {
        cout << childIndent << " left:";
        printKdTree(tree->left_child, depth + 1);
    }
    cout << endl;
    if (hasRight)
    {
        cout << childIndent << "right:";
        printKdTree(tree->right_child, depth + 1);
    }
    cout << endl;
}
// Distance between two points of equal dimensionality.
// method 0: Euclidean distance; method 1: Manhattan (L1) distance.
// Mismatched dimensions print an error and terminate the program via exit(1);
// an unknown method prints an error and returns -1.
double KDTree::measureDistance(vector<double> point1, vector<double> point2, unsigned int method){
    if (point1.size() != point2.size())
    {
        cerr << "Dimensions don't match!!" ;
        exit(1);
    }
    switch (method)
    {
    case 0:
    {
        double res = 0;
        for (size_t i = 0; i < point1.size(); ++i)
        {
            const double diff = point1[i] - point2[i];
            res += diff * diff;   // plain multiply instead of the general-purpose pow()
        }
        return sqrt(res);
    }
    case 1:
    {
        double res = 0;
        for (size_t i = 0; i < point1.size(); ++i)
        {
            // std::fabs always selects the floating-point overload (no int truncation).
            res += std::fabs(point1[i] - point2[i]);
        }
        return res;
    }
    default:
    {
        cerr << "Invalid method!!" << endl;
        return -1;
    }
    }
}
// Nearest-neighbour search: returns the point stored in `tree` that is closest to `goal`
// (Euclidean distance), or an empty vector if the tree stores no point. `goal` must have
// the same dimensionality as the stored points; otherwise measureDistance terminates
// the program.
//
// This is a depth-first branch-and-bound over the tree produced by buildKdTree. At depth d
// the split axis is d % k, and the node's own point carries the split value on that axis
// (left subtree < split value <= right subtree). The child on goal's side is explored
// first. The other child is explored only if the splitting plane is closer than the best
// distance found so far. Empty placeholder nodes (no point) are skipped. Every node that
// can hold a closer point is reached, not just the direct children of sibling subtrees.
vector<double> KDTree::searchNearestNeighbor(vector<double> goal, KDTree *tree)
{
    vector<double> best;
    if (tree == nullptr || tree->is_empty())
        return best;

    const size_t k = tree->root.size();   // dimensionality of the stored points
    double bestDistance = 0;
    bool found = false;

    // Explicit DFS stack. `bound` is a lower bound on the distance from goal to any point
    // in that subtree: the distance to the splitting plane for a "far" child, otherwise
    // inherited from the parent.
    struct Pending { KDTree* node; unsigned depth; double bound; };
    vector<Pending> stack;
    stack.push_back({tree, 0u, 0.0});
    while (!stack.empty())
    {
        const Pending cur = stack.back();
        stack.pop_back();
        KDTree* node = cur.node;
        if (node == nullptr || node->is_empty())
            continue;
        if (found && cur.bound >= bestDistance)
            continue;   // nothing in this subtree can beat the current best

        const double d = measureDistance(goal, node->root, 0);
        if (!found || d < bestDistance)
        {
            best = node->root;
            bestDistance = d;
            found = true;
        }

        const size_t axis = cur.depth % k;
        const double diff = goal[axis] - node->root[axis];
        KDTree* nearChild = diff < 0 ? node->left_child : node->right_child;
        KDTree* farChild = diff < 0 ? node->right_child : node->left_child;
        // Push the far side first so the near side (pushed last) is fully explored before it.
        stack.push_back({farChild, cur.depth + 1, std::fabs(diff)});
        stack.push_back({nearChild, cur.depth + 1, cur.bound});
    }
    return best;
}
const int data[6][2]={{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};
int main()
{
vector<vector<double> > train(6, vector<double>(2, 0)); // 初始化为6个 值为(2,0)的坐标
for (unsigned i = 0; i < 6; ++i)
for (unsigned j = 0; j < 2; ++j)
train[i][j] = data[i][j];
KDTree* kdTree = new KDTree();
kdTree->buildKdTree(kdTree, train, 0);
kdTree->printKdTree(kdTree, 0);
vector<double> goal;
goal.push_back(4);
goal.push_back(6);
vector<double> nearestNeighbor = kdTree->searchNearestNeighbor(goal, kdTree);
vector<double>::iterator beg = nearestNeighbor.begin();
cout << "The nearest neighbor is: ";
while(beg != nearestNeighbor.end()) cout << *beg++ << ",";
cout << endl;
return 0;
}