k近邻算法matlab实现_k近邻算法

Dear 丶 2023-01-01 09:53 328阅读 0赞

k 近邻法 (k-NN) 是一种基于实例的学习方法,无法转化为对参数空间的搜索问题(参数最优化 问题)。它的特点是对特征空间进行搜索。除了k近邻法,本章还对以下几个问题进行较深入的讨 论:

  • 切比雪夫距离 的计算
  • “近似误差” 与“估计误差” 的含义
  • k-d树搜索算法图解

一、算法

输入:训练集 为实例特征向 为实例的类别,

输出:实例 所属的类

设在给定距离度量下,涵盖最近k个点的邻域为

其中示性函数 为真为假

寻找使得函数 取得最大值的变量 也就是说, 看看距离 最近的k个点里面哪一类别最多,以此作为输出。

二、模型

根据模型的分类, k-NN模型属于非概率模型。

观察 可发现它与感知机不同的之处, 作为决策函数, 它并不需要任何未知参数(感知机需要确定W和b),直接从训练集的数据得到输出。

1. 距离度量

k-NN的基本思想是,特征空间中的距离反映了两个点的相似程度, 因此 “距离” 是作出分类判断 的基本依据。向量空间 的距离有多种度量方式:

(1) 不同距离度量

一般形式是闵可夫斯基距离( 范数):

当p=1时, 称为曼哈顿距离( 范数):

当p=2时,称为欧几里得距离( 范数),也就是最常用的距离::

  1. import math
  2. from itertools import combinations
  3. def L(x, y, p=2):
  4. # x1 = [1, 1], x2 = [5,1]
  5. if len(x) == len(y) and len(x) > 1:
  6. sum = 0
  7. for i in range(len(x)):
  8. sum += math.pow(abs(x[i] - y[i]), p)
  9. return math.pow(sum, 1 / p)
  10. else:
  11. return 0

下图表示平面上A、B两点之间不同的距离:a0b497bdbfb6a0425ad0100e87baf1a7.png

  • 只允许沿着坐标轴方向前进, 就像曼哈顿街区的出租车走过的距离
  • 两点之间直线最短, 通过勾股定理计算斜边的距离
  • 只剩下一个维度, 即最大的分量方向上的距离

可见p取值越大,两点之间的距离越短。

(2) 问题:为什么切比雪夫距离

其实这个问题等价于:为什么 即 空间中的向量 它的切比雪 夫长度等于它的最大分量方向上的长度。

证明: 设

不妨设 即

注意:最大分量的长度唯一, 但最大分量可能并不唯 一,设有x^{(1)}, x^{(2)}, \ldots x^{(k)}等个分量的长度都等于\left|x^{(1)}\right|$

当 即 为 时

当 即 为非最大长度分量时

计算 的切比雪夫长度:

由于已知 等于0或1,且有k个分量结果为1, 所以

因此

即 得证。

以上证明参考

(3) 平面上的圆

在平面上的图像:c086400f34a74934af558807d2aae6e6.png如果圆形的定义是 “到某一点距离相等的曲线围成的图形” ,那么在不同的距离度量下,圆形的形 状可以完全不同。为何 正则化在平面上的等高线为同心正方形, 不就很容易理解吗?

  1. k值选择

李航老师书中引入了“近似误差”和“估计误差”两个概念,但没有给出具体定义。

这里简单总结一下:

右侧两项分别是 “估计误差” 和 “近似误差”

  • 估计误差:训练集与无限数据集得到的模型的差距
  • 近似误差:限制假设空间与无限制假设空间得到的模型的差距

    • k值越小 单个样本影响越大 模型越复杂 假设空间越大 近似误差越小 (估计误差 越大),容易过拟合;
    • k值越大 单个样本影响越小 模型越简单 假设空间越小 近似误差越大(估计误差 越小),容易欠拟合。

一般通过交叉验证来确定最优的 k 值。

  1. 决策规则

k 近邻的决策规则就是 “多数表决” ,即少数服从多数, 用数学表达式表示为

等号左侧为误分类率,要是误分类率最小,就要使得 最大, 即选择集合 中最多的一类。

三、kd树

kd 树的结构

kd树是一个二叉树结构,它的每一个节点记载了 [特征坐标, 切分轴, 指向左枝的指针, 指向右枝的指针] 。其中, 特征坐标是线性空间 中的一个点

切分轴由一个整数 表示, 这里 是我们在 维空间中沿第 维进行一次分割。节点的左枝和右枝分别都是 kd 树, 并且满足:如果 是左枝的一个特征坐标, 那么 并且如果 是右 枝的一个特征坐标,那么

给定一个数据样本集 和切分轴 以下递归算法将构建一个基于该数据集的 kd 树, 每一次循环制作一 个节点:

  • 如果 记录 中唯一的一个点为当前节点的特征数据, 并且不设左枝和右枝。 指集合 中元素 . 的数量) 如果
  • 如果

    • 将 S 内所有点按照第 个坐标的大小进行排序;
    • 选出该排列后的中位元素 (如果一共有偶数个元素, 则选择中位左边或右边的元素, 左或右并无影响),作为当前节点的特征坐标, 并且记录切分轴 将 设为在 中所有排列在中位元素之前的元素; 设为在 中所有排列在中位元素后的元素;
    • 当前节点的左枝设为以 为数据集并且 为切分轴制作出的 树; 当前节点的右枝设为以 为数据集并且 为切分轴制作出 的 kd 树。再设 这里, 我们想轮流沿着每一个维度进 行分割; 是因为一共有 个维度, 在 沿着最后一个维度进行分割之后再重新回到第一个维度。)

构造 kd 树的例子

上面抽象的定义和算法确实是很不好理解,举一个例子会清楚很多。首先随机在 中随机生成 13 个点作为我们的数据集。起始的切分轴 这里 对应 轴, 而 对应 轴。06a03dfef499dd149435a3a7481e8902.png首先先沿 x 坐标进行切分,我们选出 x 坐标的中位点,获取最根部节点的坐标64398c4983945f1c7554abcb17b57bca.png

并且按照该点的x坐标将空间进行切分,所有 x 坐标小于 6.27 的数据用于构建左枝,x坐标大于 6.27 的点用于构建右枝。cb7d2d1a15f1f2a458bcdf789512d382.png节点。得到下面的树,左边的x 是指这该层的节点都是沿 x 轴进行分割的。12b992946d6dad8d821657d81aac43e9.png空间的切分如下80cc2ba4f9c71002d411879faf550bd6.png下一步中 对应 轴, 所以下面再按照 除标进行排序和切分,有732f6f5c1de2e066859c3c0554348ac0.png9c99a458eaddf97a01b9cadc487a3c97.png最后每一部分都只剩一个点,将他们记在最底部的节点中。因为不再有未被记录的点,所以不再进行切分。f7ac1d37dfb6633e3e3b536ae26d9205.png87ab0b5180b5b4a176e70a271432fb9c.png就此完成了 kd 树的构造。

kd 树上的 kNN 算法

给定一个构建于一个样本生的 kd 树, 下面的算法可以寻找距离某个点 最近的 个样本。

  1. 设 为一个有 个空位的列表, 用于保存已搜寻到的最近点.
  2. 根据 的坐标值和每个节点的切分向下搜素(也就是选,如果树的节点是按照 进行切分,并且 的 坐标小于 则向左枝进行搜索: 反之则走右枝)。
  3. 当达到一个底部节点时,将其标记为访问过. 如果 里不足 个点. 则将当前节点的特征坐标加人 如 果 L不为空并且当前节点 的特征与 的距离小于 里最长的距离,则用当前特征音换掉 中离 最远的点
  4. 如果当前节点不是整棵树最顶而节点, 执行 下(1):反之. 输出 算法完成. (1) . 向上爬一个节点。如果当前 (向上爬之后的) 节点未管被访问过, 将其标记为被访问过, 然后执行 1和2:如果当前节点被访 问过, 再次执行 (1)。

    1. 如果此时 里不足 个点, 则将节点特征加入 如果 中已满 个点, 且当前节点与 的距离小于 里最长的距离。则用节点特征豐换掉 中帝最远的点。
    2. 计算 和当前节点切分綫的距离。如果该距离大于等于 中距离 最远的距离井且 中已有 个点。则在切分线另一边不会有更近的点, 执行3: 如果该距离小于 中最远的距离或者 中不足 个点, 则切分綫另一边可能有更近的点, 因此在当前节点的另一个枝从 开始执行.

来看下面的例子:f11d5e01d5ce2a8ac22fce28c8cb6281.png首先执行1,我们按照切分找到最底部节点。首先,我们在顶部开始8d91e791bc0847fba17ddb8ba5ddffc9.png和这个节点的 x轴比较一下,5318f4aafe6910ff53ea251d8d007285.pngppp 的 x 轴更小。因此我们向左枝进行搜索:057fa3824dc7a2b0bc3f546e0b8b240a.png这次对比 y 轴,8a295ca9067758fae05e255e0c74d2ec.pngp 的 y 值更小,因此向左枝进行搜索:332b98fb03c2d1154733784d2134ec2b.png这个节点只有一个子枝,就不需要对比了。由此找到了最底部的节点 (−4.6,−10.55)。28717b2e1d10265c8b5f9fcf2342760e.png在二维图上是d09efa154746b760c368f00ade5bdf73.png此时我们执行2。将当前结点标记为访问过, 并记录下 访问过的节点就在二叉树 上显示为被划掉的好了。

然后执行 3,不是最顶端节点。执行 (1),我爬。上面的是 (−6.88,−5.4)。ca3fc5282ba265224fe46ca7f440f6a3.png8c72193420262cf08e9499c7e72edce4.png

执行 1,因为我们记录下的点只有一个,小于k=3,所以也将当前节点记录下,有 L=[(−4.6,−10.55),(−6.88,−5.4)]。再执行 2,因为当前节点的左枝是空的,所以直接跳过,回到步骤3。3看了一眼,好,不是顶部,交给你了,(1)。于是乎 (1) 又往上爬了一节。2c3120ade29c90e38e20be68a7bc47a4.png6c5c7c1b135df4934867a3294ed2be41.png1 说,由于还是不够三个点,于是将当前点也记录下,有 L=[(−4.6,−10.55),(−6.88,−5.4),(1.24,−2.86)。当然,当前结点变为被访问过的。

2又发现,当前节点有其他的分枝,并且经计算得出 p 点和 L 中的三个点的距离分别是 6.62,5.89,3.10,但是 p 和当前节点的分割线的距离只有 2.14,小于与 L 的最大距离:42193065fc68d5a3b8958a34e14add66.png因此,在分割线的另一端可能有更近的点。于是我们在当前结点的另一个分枝从头执行 1。好,我们在红线这里:32eabf846b215bfe99509281d8566a10.png要用 p 和这个节点比较 x 坐标:d10a62dec73d3eec06d6f45e080d0f28.pngp 的x 坐标更大,因此探索右枝 (1.75,12.26),并且发现右枝已经是最底部节点,因此启动 2。65f2ae4dade2a1e985a0950a41b9a20e.png经计算,(1.75,12.26)与 p 的距离是 7.48,要大于 p 与 L 的距离,因此我们不将其放入记录中。1c8cdb5ac5e94246c41379fc706ea419.png然后 3 判断出不是顶端节点,呼出 (1),爬。5a7aea1ac85c3573b8d8c1d7100ab42e.png1出来一算,这个节点与 p 的距离是 4.91,要小于 p 与 L 的最大距离 6.62。d50a23922354bde4e9c1638ac3b7f22f.png因此,我们用这个新的节点替代 L 中离 p 最远的 (−4.6,−10.55)。06c76f100ba18817fde601d828bd3ca8.png然后 2又来了,我们比对 p 和当前节点的分割线的距离1ed8bfb515d276f97d3fecad88283315.png这个距离小于 L 与 p 的最小距离,因此我们要到当前节点的另一个枝执行 1。当然,那个枝只有一个点,直接到 2。9325d140947b0024c1d4c301bbce88a3.png计算距离发现这个点离 p 比 L 更远,因此不进行替代。5c093c495f19b3fdfae9d7cd9c78c7d0.png3发现不是顶点,所以呼出 (1)。我们向上爬,a3966725840bd179c696ab5ed7f67cc6.png这个是已经访问过的了,所以再来(1),217a2150ff1e2a913746c9c7d801cf54.png好,(1)再爬,4d19160779cf0848921c266c4c3c7682.png啊!到顶点了。所以完了吗?当然不,还没轮到 3 呢。现在是 1的回合。

我们进行计算比对发现顶端节点与p的距离比L还要更远,因此不进行更新。724be2077a5ac878d43d4c4e5c7d99a9.png然后是 2,计算 p 和分割线的距离发现也是更远。2daaf58ebfa0bfe9e1ce0717a642779f.png因此也不需要检查另一个分枝。

然后执行 3,判断当前节点是顶点,因此计算完成!输出距离 p 最近的三个样本是 L=[(−6.88,−5.4),(1.24,−2.86),(−2.96,−2.5)].

C实现

  1. #include
  2. #include
  3. #include
  4. #include
  5. #include
  6. #include "kdtree.h"
  7. static inline int is_leaf(struct kdnode *node){
  8. return node->left == node->right;
  9. }
  10. static inline void swap(long *a, long *b){
  11. long tmp = *a;
  12. *a = *b;
  13. *b = tmp;
  14. }
  15. static inline double square(double d){
  16. return d * d;
  17. }
  18. static inline double distance(double *c1, double *c2, int dim){
  19. double distance = 0;
  20. while (dim-- > 0) {
  21. distance += square(*c1++ - *c2++);
  22. }
  23. return distance;
  24. }
  25. static inline double knn_max(struct kdtree *tree){
  26. return tree->knn_list_head.prev->distance;
  27. }
  28. static inline double D(struct kdtree *tree, long index, int r){
  29. return tree->coord_table[index][r];
  30. }
  31. static inline int kdnode_passed(struct kdtree *tree, struct kdnode *node){
  32. return node != NULL ? tree->coord_passed[node->coord_index] : 1;
  33. }
  34. static inline int knn_search_on(struct kdtree *tree, int k, double value, double target){
  35. return tree->knn_num }
  36. static inline void coord_index_reset(struct kdtree *tree){
  37. long i;
  38. for (i = 0; i capacity; i++) {
  39. tree->coord_indexes[i] = i;
  40. }
  41. }
  42. static inline void coord_table_reset(struct kdtree *tree){
  43. long i;
  44. for (i = 0; i capacity; i++) {
  45. tree->coord_table[i] = tree->coords + i * tree->dim;
  46. }
  47. }
  48. static inline void coord_deleted_reset(struct kdtree *tree){
  49. memset(tree->coord_deleted, 0, tree->capacity);
  50. }
  51. static inline void coord_passed_reset(struct kdtree *tree){
  52. memset(tree->coord_passed, 0, tree->capacity);
  53. }
  54. static void coord_dump_all(struct kdtree *tree){
  55. long i, j;
  56. for (i = 0; i count; i++) {
  57. long index = tree->coord_indexes[i];
  58. double *coord = tree->coord_table[index];
  59. printf("(");
  60. for (j = 0; j dim; j++) {
  61. if (j != tree->dim - 1) {
  62. printf("%.2f,", coord[j]);
  63. } else {
  64. printf("%.2f)\n", coord[j]);
  65. }
  66. }
  67. }
  68. }
  69. static void coord_dump_by_indexes(struct kdtree *tree, long low, long high, int r){
  70. long i;
  71. printf("r=%d:", r);
  72. for (i = 0; i <= high; i++) {
  73. if (i printf("%8s", " ");
  74. } else {
  75. long index = tree->coord_indexes[i];
  76. printf("%8.2f", tree->coord_table[index][r]);
  77. }
  78. }
  79. printf("\n");
  80. }
  81. static void bubble_sort(struct kdtree *tree, long low, long high, int r){
  82. long i, flag = high + 1;
  83. long *indexes = tree->coord_indexes;
  84. while (flag > 0) {
  85. long len = flag;
  86. flag = 0;
  87. for (i = low + 1; i if (D(tree, indexes[i], r) 1], r)) {
  88. swap(indexes + i - 1, indexes + i);
  89. flag = i;
  90. }
  91. }
  92. }
  93. }
  94. static void insert_sort(struct kdtree *tree, long low, long high, int r){
  95. long i, j;
  96. long *indexes = tree->coord_indexes;
  97. for (i = low + 1; i <= high; i++) {
  98. long tmp_idx = indexes[i];
  99. double tmp_value = D(tree, indexes[i], r);
  100. j = i - 1;
  101. for (; j >= low && D(tree, indexes[j], r) > tmp_value; j--) {
  102. indexes[j + 1] = indexes[j];
  103. }
  104. indexes[j + 1] = tmp_idx;
  105. }
  106. }
  107. static void quicksort(struct kdtree *tree, long low, long high, int r){
  108. if (high - low <= 32) {
  109. insert_sort(tree, low, high, r);
  110. //bubble_sort(tree, low, high, r);
  111. return;
  112. }
  113. long *indexes = tree->coord_indexes;
  114. /* median of 3 */
  115. long mid = low + (high - low) / 2;
  116. if (D(tree, indexes[low], r) > D(tree, indexes[mid], r)) {
  117. swap(indexes + low, indexes + mid);
  118. }
  119. if (D(tree, indexes[low], r) > D(tree, indexes[high], r)) {
  120. swap(indexes + low, indexes + high);
  121. }
  122. if (D(tree, indexes[high], r) > D(tree, indexes[mid], r)) {
  123. swap(indexes + high, indexes + mid);
  124. }
  125. /* D(indexes[low]) <= D(indexes[high]) <= D(indexes[mid]) */
  126. double pivot = D(tree, indexes[high], r);
  127. /* 3-way partition
  128. * +---------+-----------+---------+-------------+---------+
  129. * | pivot | <=pivot | ? | >=pivot | pivot |
  130. * +---------+-----------+---------+-------------+---------+
  131. * low lt i j gt high
  132. */
  133. long i = low - 1;
  134. long lt = i;
  135. long j = high;
  136. long gt = j;
  137. for (; ;) {
  138. while (D(tree, indexes[++i], r) while (D(tree, indexes[--j], r) > pivot && j > low) {}
  139. if (i >= j) break;
  140. swap(indexes + i, indexes + j);
  141. if (D(tree, indexes[i], r) == pivot) swap(&indexes[++lt], &indexes[i]);
  142. if (D(tree, indexes[j], r) == pivot) swap(&indexes[--gt], &indexes[j]);
  143. }
  144. /* i == j or j + 1 == i */
  145. swap(indexes + i, indexes + high);
  146. /* Move equal elements to the middle of array */
  147. long x, y;
  148. for (x = low, j = i - 1; x <= lt && j > lt; x++, j--) swap(indexes + x, indexes + j);
  149. for (y = high, i = i + 1; y >= gt && i
  150. quicksort(tree, low, j - lt + x - 1, r);
  151. quicksort(tree, i + y - gt, high, r);
  152. }
  153. static struct kdnode *kdnode_alloc(double *coord, long index, int r){
  154. struct kdnode *node = malloc(sizeof(*node));
  155. if (node != NULL) {
  156. memset(node, 0, sizeof(*node));
  157. node->coord = coord;
  158. node->coord_index = index;
  159. node->r = r;
  160. }
  161. return node;
  162. }
  163. static void kdnode_free(struct kdnode *node){
  164. free(node);
  165. }
  166. static int coord_cmp(double *c1, double *c2, int dim){
  167. int i;
  168. double ret;
  169. for (i = 0; i ret = *c1++ - *c2++;
  170. if (fabs(ret) >= DBL_EPSILON) {
  171. return ret > 0 ? 1 : -1;
  172. }
  173. }
  174. if (fabs(ret) return 0;
  175. } else {
  176. return ret > 0 ? 1 : -1;
  177. }
  178. }
  179. static void knn_list_add(struct kdtree *tree, struct kdnode *node, double distance){
  180. if (node == NULL) return;
  181. struct knn_list *head = &tree->knn_list_head;
  182. struct knn_list *p = head->prev;
  183. if (tree->knn_num == 1) {
  184. if (p->distance > distance) {
  185. p = p->prev;
  186. }
  187. } else {
  188. while (p != head && p->distance > distance) {
  189. p = p->prev;
  190. }
  191. }
  192. if (p == head || coord_cmp(p->node->coord, node->coord, tree->dim)) {
  193. struct knn_list *log = malloc(sizeof(*log));
  194. if (log != NULL) {
  195. log->node = node;
  196. log->distance = distance;
  197. log->prev = p;
  198. log->next = p->next;
  199. p->next->prev = log;
  200. p->next = log;
  201. tree->knn_num++;
  202. }
  203. }
  204. }
  205. static void knn_list_adjust(struct kdtree *tree, struct kdnode *node, double distance){
  206. if (node == NULL) return;
  207. struct knn_list *head = &tree->knn_list_head;
  208. struct knn_list *p = head->prev;
  209. if (tree->knn_num == 1) {
  210. if (p->distance > distance) {
  211. p = p->prev;
  212. }
  213. } else {
  214. while (p != head && p->distance > distance) {
  215. p = p->prev;
  216. }
  217. }
  218. if (p == head || coord_cmp(p->node->coord, node->coord, tree->dim)) {
  219. struct knn_list *log = head->prev;
  220. /* Replace the original max one */
  221. log->node = node;
  222. log->distance = distance;
  223. /* Remove from the max position */
  224. head->prev = log->prev;
  225. log->prev->next = head;
  226. /* insert as a new one */
  227. log->prev = p;
  228. log->next = p->next;
  229. p->next->prev = log;
  230. p->next = log;
  231. }
  232. }
  233. static void knn_list_clear(struct kdtree *tree){
  234. struct knn_list *head = &tree->knn_list_head;
  235. struct knn_list *p = head->next;
  236. while (p != head) {
  237. struct knn_list *prev = p;
  238. p = p->next;
  239. free(prev);
  240. }
  241. tree->knn_num = 0;
  242. }
  243. static void resize(struct kdtree *tree){
  244. tree->capacity *= 2;
  245. tree->coords = realloc(tree->coords, tree->dim * sizeof(double) * tree->capacity);
  246. tree->coord_table = realloc(tree->coord_table, sizeof(double *) * tree->capacity);
  247. tree->coord_indexes = realloc(tree->coord_indexes, sizeof(long) * tree->capacity);
  248. tree->coord_deleted = realloc(tree->coord_deleted, sizeof(char) * tree->capacity);
  249. tree->coord_passed = realloc(tree->coord_passed, sizeof(char) * tree->capacity);
  250. coord_table_reset(tree);
  251. coord_index_reset(tree);
  252. coord_deleted_reset(tree);
  253. coord_passed_reset(tree);
  254. }
  255. static void kdnode_dump(struct kdnode *node, int dim){
  256. int i;
  257. if (node->coord != NULL) {
  258. printf("(");
  259. for (i = 0; i if (i != dim - 1) {
  260. printf("%.2f,", node->coord[i]);
  261. } else {
  262. printf("%.2f)\n", node->coord[i]);
  263. }
  264. }
  265. } else {
  266. printf("(none)\n");
  267. }
  268. }
  269. void kdtree_insert(struct kdtree *tree, double *coord){
  270. if (tree->count + 1 > tree->capacity) {
  271. resize(tree);
  272. }
  273. memcpy(tree->coord_table[tree->count++], coord, tree->dim * sizeof(double));
  274. }
  275. static void knn_pickup(struct kdtree *tree, struct kdnode *node, double *target, int k){
  276. double dist = distance(node->coord, target, tree->dim);
  277. if (tree->knn_num knn_list_add(tree, node, dist);
  278. } else {
  279. if (dist knn_list_adjust(tree, node, dist);
  280. } else if (fabs(dist - knn_max(tree)) knn_list_add(tree, node, dist);
  281. }
  282. }
  283. }
  284. static void kdtree_search_recursive(struct kdtree *tree, struct kdnode *node, double *target, int k, int *pickup){
  285. if (node == NULL || kdnode_passed(tree, node)) {
  286. return;
  287. }
  288. int r = node->r;
  289. if (!knn_search_on(tree, k, node->coord[r], target[r])) {
  290. return;
  291. }
  292. if (*pickup) {
  293. tree->coord_passed[node->coord_index] = 1;
  294. knn_pickup(tree, node, target, k);
  295. kdtree_search_recursive(tree, node->left, target, k, pickup);
  296. kdtree_search_recursive(tree, node->right, target, k, pickup);
  297. } else {
  298. if (is_leaf(node)) {
  299. *pickup = 1;
  300. } else {
  301. if (target[r] <= node->coord[r]) {
  302. kdtree_search_recursive(tree, node->left, target, k, pickup);
  303. kdtree_search_recursive(tree, node->right, target, k, pickup);
  304. } else {
  305. kdtree_search_recursive(tree, node->right, target, k, pickup);
  306. kdtree_search_recursive(tree, node->left, target, k, pickup);
  307. }
  308. }
  309. /* back track and pick up */
  310. if (*pickup) {
  311. tree->coord_passed[node->coord_index] = 1;
  312. knn_pickup(tree, node, target, k);
  313. }
  314. }
  315. }
  316. void kdtree_knn_search(struct kdtree *tree, double *target, int k){
  317. if (k > 0) {
  318. int pickup = 0;
  319. kdtree_search_recursive(tree, tree->root, target, k, &pickup);
  320. }
  321. }
  322. void kdtree_delete(struct kdtree *tree, double *coord){
  323. int r = 0;
  324. struct kdnode *node = tree->root;
  325. struct kdnode *parent = node;
  326. while (node != NULL) {
  327. if (node->coord == NULL) {
  328. if (parent->right->coord == NULL) {
  329. break;
  330. } else {
  331. node = parent->right;
  332. continue;
  333. }
  334. }
  335. if (coord[r] coord[r]) {
  336. parent = node;
  337. node = node->left;
  338. } else if (coord[r] > node->coord[r]) {
  339. parent = node;
  340. node = node->right;
  341. } else {
  342. int ret = coord_cmp(coord, node->coord, tree->dim);
  343. if (ret 0) {
  344. parent = node;
  345. node = node->left;
  346. } else if (ret > 0) {
  347. parent = node;
  348. node = node->right;
  349. } else {
  350. node->coord = NULL;
  351. break;
  352. }
  353. }
  354. r = (r + 1) % tree->dim;
  355. }
  356. }
  357. static void kdnode_build(struct kdtree *tree, struct kdnode **nptr, int r, long low, long high){
  358. if (low == high) {
  359. long index = tree->coord_indexes[low];
  360. *nptr = kdnode_alloc(tree->coord_table[index], index, r);
  361. } else if (low /* Sort and fetch the median to build a balanced BST */
  362. quicksort(tree, low, high, r);
  363. long median = low + (high - low) / 2;
  364. long median_index = tree->coord_indexes[median];
  365. struct kdnode *node = *nptr = kdnode_alloc(tree->coord_table[median_index], median_index, r);
  366. r = (r + 1) % tree->dim;
  367. kdnode_build(tree, &node->left, r, low, median - 1);
  368. kdnode_build(tree, &node->right, r, median + 1, high);
  369. }
  370. }
  371. static void kdtree_build(struct kdtree *tree){
  372. kdnode_build(tree, &tree->root, 0, 0, tree->count - 1);
  373. }
  374. void kdtree_rebuild(struct kdtree *tree){
  375. long i, j;
  376. size_t size_of_coord = tree->dim * sizeof(double);
  377. for (i = 0, j = 0; j count; i++, j++) {
  378. while (j count && tree->coord_deleted[j]) {
  379. j++;
  380. }
  381. if (i != j && j count) {
  382. memcpy(tree->coord_table[i], tree->coord_table[j], size_of_coord);
  383. tree->coord_deleted[i] = 0;
  384. }
  385. }
  386. tree->count = i;
  387. coord_index_reset(tree);
  388. kdtree_build(tree);
  389. }
  390. struct kdtree *kdtree_init(int dim){
  391. struct kdtree *tree = malloc(sizeof(*tree));
  392. if (tree != NULL) {
  393. tree->root = NULL;
  394. tree->dim = dim;
  395. tree->count = 0;
  396. tree->capacity = 65536;
  397. tree->knn_list_head.next = &tree->knn_list_head;
  398. tree->knn_list_head.prev = &tree->knn_list_head;
  399. tree->knn_list_head.node = NULL;
  400. tree->knn_list_head.distance = 0;
  401. tree->knn_num = 0;
  402. tree->coords = malloc(dim * sizeof(double) * tree->capacity);
  403. tree->coord_table = malloc(sizeof(double *) * tree->capacity);
  404. tree->coord_indexes = malloc(sizeof(long) * tree->capacity);
  405. tree->coord_deleted = malloc(sizeof(char) * tree->capacity);
  406. tree->coord_passed = malloc(sizeof(char) * tree->capacity);
  407. coord_index_reset(tree);
  408. coord_table_reset(tree);
  409. coord_deleted_reset(tree);
  410. coord_passed_reset(tree);
  411. }
  412. return tree;
  413. }
  414. static void kdnode_destroy(struct kdnode *node){
  415. if (node == NULL) return;
  416. kdnode_destroy(node->left);
  417. kdnode_destroy(node->right);
  418. kdnode_free(node);
  419. }
  420. void kdtree_destroy(struct kdtree *tree){
  421. kdnode_destroy(tree->root);
  422. knn_list_clear(tree);
  423. free(tree->coords);
  424. free(tree->coord_table);
  425. free(tree->coord_indexes);
  426. free(tree->coord_deleted);
  427. free(tree->coord_passed);
  428. free(tree);
  429. }
  430. #define _KDTREE_DEBUG
  431. #ifdef _KDTREE_DEBUG
  432. struct kdnode_backlog {
  433. struct kdnode *node;
  434. int next_sub_idx;
  435. };
  436. void kdtree_dump(struct kdtree *tree){
  437. int level = 0;
  438. struct kdnode *node = tree->root;
  439. struct kdnode_backlog nbl, *p_nbl = NULL;
  440. struct kdnode_backlog nbl_stack[KDTREE_MAX_LEVEL];
  441. struct kdnode_backlog *top = nbl_stack;
  442. for (; ;) {
  443. if (node != NULL) {
  444. /* Fetch the pop-up backlogged node's sub-id.
  445. * If not backlogged, fetch the first sub-id. */
  446. int sub_idx = p_nbl != NULL ? p_nbl->next_sub_idx : KDTREE_RIGHT_INDEX;
  447. /* Backlog should be left in next loop */
  448. p_nbl = NULL;
  449. /* Backlog the node */
  450. if (is_leaf(node) || sub_idx == KDTREE_LEFT_INDEX) {
  451. top->node = NULL;
  452. top->next_sub_idx = KDTREE_RIGHT_INDEX;
  453. } else {
  454. top->node = node;
  455. top->next_sub_idx = KDTREE_LEFT_INDEX;
  456. }
  457. top++;
  458. level++;
  459. /* Draw lines as long as sub_idx is the first one */
  460. if (sub_idx == KDTREE_RIGHT_INDEX) {
  461. int i;
  462. for (i = 1; i if (i == level - 1) {
  463. printf("%-8s", "+-------");
  464. } else {
  465. if (nbl_stack[i - 1].node != NULL) {
  466. printf("%-8s", "|");
  467. } else {
  468. printf("%-8s", " ");
  469. }
  470. }
  471. }
  472. kdnode_dump(node, tree->dim);
  473. }
  474. /* Move down according to sub_idx */
  475. node = sub_idx == KDTREE_LEFT_INDEX ? node->left : node->right;
  476. } else {
  477. p_nbl = top == nbl_stack ? NULL : --top;
  478. if (p_nbl == NULL) {
  479. /* End of traversal */
  480. break;
  481. }
  482. node = p_nbl->node;
  483. level--;
  484. }
  485. }
  486. }
  487. #endif

python

  1. class kdtree(object):
  2. # 创建 kdtree
  3. # point_list 是一个 list 的 pair,pair[0] 是一 tuple 的特征,pair[1] 是类别
  4. def __init__(self, point_list, depth=0, root=None):
  5. if len(point_list)>0:
  6. # 轮换按照树深度选择坐标轴
  7. k = len(point_list[0][0])
  8. axis = depth % k
  9. # 选中位线,切
  10. point_list.sort(key=lambda x:x[0][axis])
  11. median = len(point_list) // 2
  12. self.axis = axis
  13. self.root = root
  14. self.size = len(point_list)
  15. # 造节点
  16. self.node = point_list[median]
  17. # 递归造左枝和右枝
  18. if len(point_list[:median])>0:
  19. self.left = kdtree(point_list[:median], depth+1, self)
  20. else:
  21. self.left = None
  22. if len(point_list[median+1:])>0:
  23. self.right = kdtree(point_list[median+1:], depth+1, self)
  24. else:
  25. self.right = None
  26. # 记录是按哪个方向切的还有树根
  27. else:
  28. return None
  29. # 在树上加一点
  30. def insert(self, point):
  31. self.size += 1
  32. # 分析是左还是右,递归加在叶子上
  33. if point[0][self.axis]0][self.axis]:if self.left!=None:
  34. self.left.insert(point)else:
  35. self.left = kdtree([point], self.axis+1, self)else:if self.right!=None:
  36. self.right.insert(point)else:
  37. self.right = kdtree([point], self.axis+1, self)# 输入一点# 按切分寻找叶子def find_leaf(self, point):if self.left==None and self.right==None:return selfelif self.left==None:return self.right.find_leaf(point)elif self.right==None:return self.left.find_leaf(point)elif point[self.axis]0][self.axis]:return self.left.find_leaf(point)else:return self.right.find_leaf(point)# 查找最近的 k 个点,复杂度 O(DlogN),D是维度,N是树的大小# 输入一点、一距离函数、一k。距离函数默认是 L_2def knearest(self, point, k=1, dist=lambda x,y: sum(map(lambda u,v:(u-v)**2,x,y))):# 往下戳到最底叶
  38. leaf = self.find_leaf(point)# 从叶子网上爬return leaf.k_down_up(point, k, dist, result=[], stop=self, visited=None)# 从下往上爬函数,stop是到哪里去,visited是从哪里来def k_down_up(self, point,k, dist, result=[],stop=None, visited=None):# 选最长距离if result==[]:
  39. max_dist = 0else:
  40. max_dist = max([x[1] for x in result])
  41. other_result=[]# 如果离分界线的距离小于现有最大距离,或者数据点不够,就从另一边的树根开始刨if (self.left==visited and self.node[0][self.axis]-point[self.axis]and self.right!=None)\or (len(result)and self.left==visited and self.right!=None):
  42. other_result=self.right.knearest(point,k, dist)if (self.right==visited and point[self.axis]-self.node[0][self.axis]and self.left!=None)\or (len(result)and self.right==visited and self.left!=None):
  43. other_result=self.left.knearest(point, k, dist)# 刨出来的点放一起,选前 k
  44. result.append((self.node, dist(point, self.node[0])))
  45. result = sorted(result+other_result, key=lambda pair: pair[1])[:k]# 到停点就返回结果if self==stop:return result# 没有就带着现有结果接着往上爬else:return self.root.k_down_up(point,k, dist, result, stop, self)# 输入 特征、类别、k、距离函数# 返回这个点属于该类别的概率def kNN_prob(self, point, label, k, dist=lambda x,y: sum(map(lambda u,v:(u-v)**2,x,y))):
  46. nearests = self.knearest(point, k, dist)return float(len([pair for pair in nearests if pair[0][1]==label]))/float(len(nearests))# 输入 特征、k、距离函数# 返回该点概率最大的类别以及相对应的概率def kNN(self, point, k, dist=lambda x,y: sum(map(lambda u,v:(u-v)**2,x,y))):
  47. nearests = self.knearest(point, k , dist)
  48. statistics = {}for data in nearests:
  49. label = data[0][1]if label not in statistics:
  50. statistics[label] = 1else:
  51. statistics[label] += 1
  52. max_label = max(statistics.iteritems(), key=operator.itemgetter(1))[0]return max_label, float(statistics[max_label])/float(len(nearests))

发表评论

表情:
评论列表 (有 0 条评论,328人围观)

还没有评论,来说两句吧...

相关阅读

    相关 K近邻算法

    一、kNN算法的工作原理 官方解释:存在一个样本数据集,也称作训练样本集,并且样本中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系,输入没有标签的新数据后

    相关 k近邻算法matlab实现_k近邻算法

    k 近邻法 (k-NN) 是一种基于实例的学习方法,无法转化为对参数空间的搜索问题(参数最优化 问题)。它的特点是对特征空间进行搜索。除了k近邻法,本章还对以下几个问题进行较深

    相关 k-近邻算法

    从今天开始,与大家分享我学习《Machine Learning In Action》这本书的笔记与心得。我会将源码加以详细的注释,这是我自己学习的一个过程,也是想通过这种方式帮

    相关 K近邻分类算法

    K最近邻(k-Nearest Neighbor,KNN)分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。该方法的思路是:如果一个样本在特征空间中的k个最相似

    相关 K-近邻算法(KNN)

         拜读大神的系列教程,大神好像姓崔(猜测),大神根据《机器学习实战》来讲解,讲的很清楚,读了大神的博客后,我也把我自己吸收的写下来,可能有很多错误之处,希望拍砖(拍轻点

    相关 knn(k近邻算法

    一、什么是knn算法 knn算法实际上是利用训练数据集对特征向量空间进行划分,并作为其分类的模型。其输入是实例的特征向量,输出为实例的类别。寻找最近的k个数据,推测新数据...