kd-tree的实现

£神魔★判官ぃ 2022-08-05 09:00 239阅读 0赞

参考百度百科：http://baike.baidu.com/link?url=JLBeRUhL6WLyp8R6TAFDD8swLfazjQnOaSXBY3AydkrVQG8XpCJ8EIh4bWpB02wQxxzPrK723ulRCzSKxkFLy_

下面是我的 C++ 实现：

// kd-tree.cpp : Defines the entry point for the console application.
//
#include "stdafx.h"  // presumably the MSVC precompiled header -- keep it first
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <vector>
using namespace std;
  8. #define KeyType double
  9. class kdtree
  10. {
  11. public:
  12. struct kdnode
  13. {
  14. kdnode*lnode, *rnode, *parent;
  15. double*value;
  16. int splitdim;//该节点在哪个维度分裂
  17. kdnode()
  18. {
  19. lnode = rnode = parent = NULL;
  20. }
  21. };
  22. private:
  23. unsigned int B;//用于构建kdb树时指定叶子中包含的数据个数,默认为2,既包含[B/2,B)个数据
  24. int dim;//维数
  25. kdnode*root;
  26. private:
  27. //选择在哪个维度分裂,合理的选择分裂可以减小树的高度
  28. int getsplitdim(vector<KeyType*>&input);
  29. //分裂数据集,left,right为分裂结果
  30. void split_dataset(vector<KeyType*>&input, int const splitdim, vector<KeyType*>&left, vector<KeyType*>&right);
  31. void create(kdnode*&node, vector<KeyType*>&input);
  32. void goback();
  33. double distance(KeyType*const aa, KeyType*const bb)
  34. {
  35. double dis = 0;
  36. for (int i = 0; i < dim; i++)
  37. dis += pow(double(aa[i] - bb[i]), double(2));
  38. return sqrt(dis);
  39. }
  40. bool UDless(int const dth, KeyType* elem1, KeyType*elem2)
  41. {
  42. return elem1[dth] < elem2[dth];
  43. }
  44. public:
  45. kdtree(int dimen = 2)
  46. {
  47. root = NULL;
  48. _ASSERTE(dimen > 1);
  49. dim = dimen;
  50. }
  51. KeyType* nearest(KeyType*const val);
  52. //void insert();
  53. void create(KeyType**&indata, int datanums);
  54. kdnode*get_root(){ return root; }
  55. ~kdtree()
  56. {
  57. if (root == NULL)
  58. return;
  59. vector<kdnode*>aa, bb;
  60. aa.push_back(root);
  61. while (!aa.empty())
  62. {
  63. kdnode*cc = aa.back();
  64. bb.push_back(cc);
  65. aa.pop_back();
  66. if (cc->lnode != NULL)
  67. aa.push_back(cc->lnode);
  68. if (cc->rnode != NULL)
  69. aa.push_back(cc->rnode);
  70. }
  71. for (int i = 0; i < bb.size(); i++)
  72. delete bb[i];
  73. };
  74. };
  75. void kdtree::create(KeyType**&indata, int datanums)
  76. {
  77. for (int i = 0; i < datanums; i++)
  78. {
  79. for (int j = 0; j < dim; j++)
  80. cout << indata[i][j] << " ";
  81. cout << endl;
  82. }
  83. root = new kdnode;
  84. vector<KeyType*>input;
  85. for (int i = 0; i < datanums; i++)
  86. input.push_back(indata[i]);
  87. create(root, input);
  88. }
  89. void kdtree::create(kdnode*&node, vector<KeyType*>&input)
  90. {
  91. if (input.size() < 1)
  92. return;
  93. int splitinfo = getsplitdim(input);
  94. node->value = input[input.size() / 2];
  95. node->splitdim = splitinfo;
  96. vector<KeyType*>left, right;
  97. //left,right为输出类型
  98. split_dataset(input, splitinfo, left, right);
  99. if (left.size() > 0)
  100. {
  101. kdnode*lnode = new kdnode;
  102. lnode->parent = node;
  103. node->lnode = lnode;
  104. create(lnode, left);
  105. }
  106. if (right.size() > 0)
  107. {
  108. kdnode*rnode = new kdnode;
  109. rnode->parent = node;
  110. node->rnode = rnode;
  111. create(rnode, right);
  112. }
  113. }
  114. void kdtree::split_dataset(vector<KeyType*>&input,
  115. int const splitdim, vector<KeyType*>&left, vector<KeyType*>&right)
  116. {
  117. int nums = input.size();
  118. left.assign(input.begin(), input.begin() + nums / 2);//将区间[first,last)的元素赋值到当前的vector容器中
  119. input.erase(input.begin(), input.begin() + nums / 2 + 1);//将区间[first,last)的元素删除
  120. right = input;
  121. }
  122. int kdtree::getsplitdim(vector<KeyType*>&input)//根据方差决定在那一个维度分裂
  123. {
  124. double maxs = -1;
  125. int splitdim;
  126. int nums = input.size();
  127. // 利用函数对象实现升降排序
  128. struct CompNameEx{
  129. CompNameEx(bool asce, int k) : asce_(asce), kk(k)
  130. {}
  131. bool operator()(KeyType*const& pl, KeyType*const& pr)
  132. {
  133. return asce_ ? pl[kk] < pr[kk] : pr[kk] < pl[kk]; // 《Eff STL》条款21: 永远让比较函数对相等的值返回false
  134. }
  135. private:
  136. bool asce_;
  137. int kk;
  138. };
  139. for (int i = 0; i < dim; i++)
  140. {
  141. double s = 0;
  142. double mean = 0;
  143. for (int j = 0; j < nums; j++)
  144. mean += input[j][i];
  145. mean /= double(nums);
  146. for (int j = 0; j < nums; j++)
  147. {
  148. s += pow(double(input[j][i] - mean), double(2));
  149. }
  150. if (s > maxs)
  151. {
  152. splitdim = i;
  153. maxs = s;
  154. }
  155. }
  156. sort(input.begin(), input.end(), CompNameEx(true, splitdim));
  157. return splitdim;
  158. }
  159. KeyType* kdtree::nearest(KeyType*const val)
  160. {
  161. if (root == NULL)
  162. return NULL;
  163. double mindis = 100000;
  164. vector<kdnode*>aa;
  165. kdnode*node = root;
  166. KeyType*tt=NULL;
  167. while (node != NULL)
  168. {
  169. aa.push_back(node);
  170. if (val[node->splitdim] > node->value[node->splitdim])
  171. node = node->rnode;
  172. else
  173. node = node->lnode;
  174. }
  175. double dis = distance(val, aa.back()->value);
  176. if (dis < mindis)
  177. {
  178. mindis = dis;
  179. tt = aa.back()->value;
  180. }
  181. aa.pop_back();
  182. while (!aa.empty())
  183. {
  184. dis = distance(val, aa.back()->value);
  185. if (dis < mindis)
  186. {
  187. mindis = dis;
  188. tt = aa.back()->value;
  189. int sd = aa.back()->splitdim;
  190. if (val[sd] < aa.back()->value[sd])
  191. {
  192. kdnode*rr = aa.back()->rnode;
  193. aa.pop_back();
  194. if (rr)
  195. aa.push_back(rr);
  196. }
  197. else
  198. {
  199. kdnode*ll = aa.back()->lnode;
  200. aa.pop_back();
  201. if (ll)
  202. aa.push_back(ll);
  203. }
  204. }
  205. else
  206. aa.pop_back();
  207. }
  208. return tt;
  209. }
  210. int _tmain(int argc, _TCHAR* argv[])
  211. {
  212. kdtree kd(2);
  213. KeyType bb[6][2] = { 2, 3, 5, 4, 9, 6, 4, 7, 8, 1, 7, 2 };// { 12, 45, 34, 12, 17, 34, 43, 889, 86, 54 };
  214. KeyType** in = new KeyType*[6];
  215. for (int i = 0; i < 6; i++)
  216. {
  217. for (int j = 0; j < 2; j++)
  218. cout << bb[i][j] << " ";
  219. cout << endl;
  220. }
  221. for (int i = 0; i < 6; i++)
  222. in[i] = bb[i];
  223. kdtree::kdnode*root = kd.get_root();
  224. kd.create(in, 6);
  225. root = kd.get_root();
  226. KeyType hh[2] = { 2, 4.5 };
  227. KeyType*n = kd.nearest(hh);
  228. delete in;
  229. system("pause");
  230. return 0;
  231. }

在 Python 中使用 kd-tree

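scipy.spatial.KDTree（下面的交互示例沿用 Python 2 写法：运行前需先 `import numpy as np`；在 Python 3 中 `zip()` 返回迭代器，应写成 `spatial.KDTree(list(zip(x.ravel(), y.ravel())))`。）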

  1. >>> from scipy import spatial
  2. >>> x, y = np.mgrid[0:5, 2:8]
  3. >>> tree = spatial.KDTree(zip(x.ravel(), y.ravel()))
  4. >>> tree.data
  5. array([[0, 2],
  6. [0, 3],
  7. [0, 4],
  8. [0, 5],
  9. [0, 6],
  10. [0, 7],
  11. [1, 2],
  12. [1, 3],
  13. [1, 4],
  14. [1, 5],
  15. [1, 6],
  16. [1, 7],
  17. [2, 2],
  18. [2, 3],
  19. [2, 4],
  20. [2, 5],
  21. [2, 6],
  22. [2, 7],
  23. [3, 2],
  24. [3, 3],
  25. [3, 4],
  26. [3, 5],
  27. [3, 6],
  28. [3, 7],
  29. [4, 2],
  30. [4, 3],
  31. [4, 4],
  32. [4, 5],
  33. [4, 6],
  34. [4, 7]])
  35. >>> pts = np.array([[0, 0], [2.1, 2.9]])
  36. >>> tree.query(pts)
  37. (array([ 2. , 0.14142136]), array([ 0, 13]))

更多用法详见 scipy.spatial.KDTree 的官方文档与源码。

发表评论

表情:
评论列表 (有 0 条评论,239人围观)

还没有评论,来说两句吧...

相关阅读