【统计学习方法】k近邻 kd树的python实现

深藏阁楼爱情的钟 2022-06-10 10:58 310阅读 0赞

前言

代码可在Github上下载:代码下载

k近邻可以算是机器学习中易于理解、实现的一个算法了,《机器学习实战》的第一章便是以它作为介绍来入门。而k近邻的算法可以简述为通过遍历数据集的每个样本进行距离测量,并找出距离最小的k个点。但是这样一来一旦样本数目庞大的时候,就容易造成大量的计算。

所以需要将数据用树形结构存储,以便快速检索,这也就是本文要阐述的kd树。

实现

分为两部分,一个是kd树建立,一个是kd树的搜索。

#

kd树建立

  1. # --*-- coding:utf-8 --*--
  2. import numpy as np

先定义一下字符集还有包。

首先我们先实现一个结点类,用来表示kd。

  1. class Node:
  2. def __init__(self, data, lchild = None, rchild = None):
  3. self.data = data
  4. self.lchild = lchild
  5. self.rchild = rchild

一个结点包含着结点域,左孩子,右孩子。(如果不熟二叉树的话建议先看一些数据结构二叉树的相关知识,以及先序遍历,中序遍历还有后序遍历的相关代码)

二叉树相关代码(C语言实现)

然后是创建kd树的代码,主要根据P41,算法3.2来实现的。

  1. def create(self, dataSet, depth): #创建kd树,返回根结点
  2. if (len(dataSet) > 0):
  3. m, n = np.shape(dataSet) #求出样本行,列
  4. midIndex = m / 2 #中间数的索引位置
  5. axis = depth % n #判断以哪个轴划分数据,对应书中算法3.2(2)公式j()
  6. sortedDataSet = self.sort(dataSet, axis) #进行排序
  7. node = Node(sortedDataSet[midIndex]) #将节点数据域设置为中位数,具体参考下书本
  8. # print sortedDataSet[midIndex]
  9. leftDataSet = sortedDataSet[: midIndex] #将中位数的左边创建2个副本
  10. rightDataSet = sortedDataSet[midIndex+1 :]
  11. print leftDataSet
  12. print rightDataSet
  13. node.lchild = self.create(leftDataSet, depth+1) #将中位数左边样本传入来递归创建树
  14. node.rchild = self.create(rightDataSet, depth+1)
  15. return node
  16. else:
  17. return None

以上的代码通过看注释应该可以了解一二,其中需要按轴j(mod k)+1,也就是【depth(深度) mod n(特征数)+1】为轴划分中位数,然后决定插入数据到左结点,右结点。然后注意一下为什么上面的按轴划分的公式是【depth(深度) mod n(特征数)】,这是因为python的数组下标是从0开始的。

  1. def sort(self, dataSet, axis): #采用冒泡排序,利用aixs作为轴进行划分
  2. sortDataSet = dataSet[:] #由于不能破坏原样本,此处建立一个副本
  3. m, n = np.shape(sortDataSet)
  4. for i in range(m):
  5. for j in range(0, m - i - 1):
  6. if (sortDataSet[j][axis] > sortDataSet[j+1][axis]):
  7. temp = sortDataSet[j]
  8. sortDataSet[j] = sortDataSet[j+1]
  9. sortDataSet[j+1] = temp
  10. print sortDataSet
  11. return sortDataSet

创建树的时候为了找中位数,需要按轴(某一维度)排序,找出中间那个数。这里我用了冒泡排序。

  1. def preOrder(self, node):
  2. if node != None:
  3. print "tttt->%s" % node.data
  4. self.preOrder(node.lchild)
  5. self.preOrder(node.rchild)

当然我选择了先序遍历来简单检查下树的创建有没有问题。(看下这棵树能否正常遍历,这步可忽略)

kd树搜索

  1. def search(self, tree, x): #搜索
  2. self.nearestPoint = None #保存最近的点
  3. self.nearestValue = 0 #保存最近的值
  4. def travel(node, depth = 0): #递归搜索
  5. if node != None: #递归终止条件
  6. n = len(x) #特征数
  7. axis = depth % n #计算轴
  8. if x[axis] < node.data[axis]: #如果数据小于结点,则往左结点找
  9. travel(node.lchild, depth+1)
  10. else:
  11. travel(node.rchild, depth+1)
  12. #以下是递归完毕,对应算法3.3(3)
  13. distNodeAndX = self.dist(x, node.data) #目标和节点的距离判断
  14. if (self.nearestPoint == None): #确定当前点,更新最近的点和最近的值,对应算法3.3(3)(a)
  15. self.nearestPoint = node.data
  16. self.nearestValue = distNodeAndX
  17. elif (self.nearestValue > distNodeAndX):
  18. self.nearestPoint = node.data
  19. self.nearestValue = distNodeAndX
  20. print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
  21. if (abs(x[axis] - node.data[axis]) <= self.nearestValue): #确定是否需要去子节点的区域去找(圆的判断),对应算法3.3(3)(b)
  22. if x[axis] < node.data[axis]:
  23. travel(node.rchild, depth+1)
  24. else:
  25. travel(node.lchild, depth + 1)
  26. travel(tree)
  27. return self.nearestPoint
  28. def dist(self, x1, x2): #欧式距离的计算
  29. return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

搜索树的时候比较麻烦,首先先说下原理吧。

(1) 在kd树中找出包含目标点x的叶结点:从根结点出发,递归的向下访问kd树。若目标点当前维的坐标值小于切分点的坐标值,则移动到左子结点,否则移动到右子结点。直到子结点为叶结点为止;
(2) 以此叶结点为“当前最近点”;
(3) 递归的向上回退,在每个结点进行以下操作:
  (a) 如果该结点保存的实例点比当前最近点距目标点更近,则以该实例点为“当前最近点”;
  (b) 当前最近点一定存在于该结点一个子结点对应的区域。检查该子结点的父结点的另一个子结点对应的区域是否有更近的点。具体的,检查另一个子结点对应的区域是否与以目标点为球心、以目标点与“当前最近点”间的距离为半径的超球体相交。如果相交,可能在另一个子结点对应的区域内存在距离目标更近的点,移动到另一个子结点。接着,递归的进行最近邻搜索。如果不相交,向上回退。
(4) 当回退到根结点时,搜索结束。最后的“当前最近点”即为x的最近邻点。

注意了,先按步骤找到叶结点,然后回朔的时候要做两件事,(a)是更新最新点,(b)是检查是否需要检查父结节点的另外一个结点的区域。

  1. if x[axis] < node.data[axis]: #如果数据小于结点,则往左结点找
  2. travel(node.lchild, depth+1)
  3. else:
  4. travel(node.rchild, depth+1)

这段是类似于二叉查找树的过程,直至查找到叶子节点。

  1. #以下是递归完毕后,往父结点方向回朔,对应算法3.3(3)
  2. distNodeAndX = self.dist(x, node.data) #目标和节点的距离判断
  3. if (self.nearestPoint == None): #确定当前点,更新最近的点和最近的值,对应算法3.3(3)(a)
  4. self.nearestPoint = node.data
  5. self.nearestValue = distNodeAndX
  6. elif (self.nearestValue > distNodeAndX):
  7. self.nearestPoint = node.data
  8. self.nearestValue = distNodeAndX
  9. print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
  10. if (abs(x[axis] - node.data[axis]) <= self.nearestValue): #确定是否需要去子节点的区域去找(圆的判断),对应算法3.3(3)(b)
  11. if x[axis] < node.data[axis]:
  12. travel(node.rchild, depth+1)
  13. else:
  14. travel(node.lchild, depth + 1)

这段代码,就是P43算法3.3(3)中的内容。

(a)容易实现,但是(b)的原理是判断目标点和最近的一个点的距离为半径画一个圆(就如书本P44图3.5,目标点S和当前最近点D形成了一个圆),是否跟父结点按轴分的那条线(也就是圆内的那条直线)有交集。

说白了,就是公式:|目标值(按轴读值) - 父节点(按轴读值)| < 最近的值(圆的半径),这里按轴读取就是P44图3.5中的x的y轴的值,然后减去相交的那条直线y轴的值,看是否小于半径。

注意:评论里有说这里的node.data不知道是指示哪个结点。这里要说明的是,这个node并不是父节点,而是当前结点。这里如果你对数据结构的二叉树不太熟的话,是不太容易get到这个点的。我只能稍微说下。

“这里应该了解下二叉查找树的过程”

如果找到了的话,把另一结点重新递归一次就好了。对应以下代码:

  1. travel(node.rchild, depth+1)

最后在github贴出全部代码(如果方便的话麻烦给个赞吧,您的支持就是我前进的动力),然后来运行一下代码(这段代码在python3.5下成功运行)。

KNN(KDtree)代码下载

结果输出(5,4)

发表评论

表情:
评论列表 (有 0 条评论,310人围观)

还没有评论,来说两句吧...

相关阅读