首页 > 代码库 > KD Tree算法

KD Tree算法

参考:http://blog.csdn.net/v_july_v/article/details/8203674

#!/usr/bin/env python
# -*- coding:utf8 -*-

__author__ = 'zky@msn.cn'

import sys
import numpy
import heapq
import Queue

class KDNode(object):
    """A named point that also serves as a node of a KDTree.

    ``name`` identifies the point and ``feature`` holds its coordinates.
    ``ki`` (index of the splitting dimension), ``is_leaf``, ``kd_left`` and
    ``kd_right`` are filled in when a KDTree is built over the node.
    """

    def __init__(self, name, feature):
        self.name = name
        self.ki = -1  # splitting dimension; -1 until placed in a tree
        self.is_leaf = False
        self.feature = feature
        self.kd_left = None
        self.kd_right = None

    def __repr__(self):
        return 'KDNode(%r, %r)' % (self.name, self.feature)

    def traverse(self, seq, order='in'):
        """Append the nodes of this subtree to ``seq`` in the given order.

        :param seq: list the nodes are appended to (mutated in place).
        :param order: 'in' (left, self, right), 'pre' (self, left, right)
            or 'post' (left, right, self).
        :raises ValueError: if ``order`` is not one of those values.
        """
        if order not in ('in', 'pre', 'post'):
            # an assert here would be stripped under `python -O`, turning a
            # typo into a silent no-op
            raise ValueError(
                "order must be 'in', 'pre' or 'post', got %r" % (order,))
        if order == 'pre':
            seq.append(self)
        if self.kd_left is not None:
            self.kd_left.traverse(seq, order)
        if order == 'in':
            seq.append(self)
        if self.kd_right is not None:
            self.kd_right.traverse(seq, order)
        if order == 'post':
            seq.append(self)

class NodeDistance(object):
    """Pairs a tree node with its distance to a query point.

    Ordering is deliberately reversed: a *smaller* distance compares as
    *greater*. heapq only provides a min-heap, so with this ordering the heap
    root is the farthest candidate, i.e. the one to evict when a closer node
    turns up.
    """

    def __init__(self, kd_node, distance):
        self.kd_node = kd_node
        self.distance = distance

    def __repr__(self):
        return 'NodeDistance(%r, %r)' % (self.kd_node, self.distance)

    # Python 3 ignores __cmp__, so heapq and the < / > operators need the
    # rich comparisons; all of them are reversed on distance.
    def __lt__(self, other):
        return self.distance > other.distance

    def __le__(self, other):
        return self.distance >= other.distance

    def __gt__(self, other):
        return self.distance < other.distance

    def __ge__(self, other):
        return self.distance <= other.distance

    # Kept for Python 2, where it still drives ==, != and cmp().
    def __cmp__(self, other):
        ret = other.distance - self.distance
        if ret > 0:
            return 1
        elif ret < 0:
            return -1
        else:
            return 0

def euclidean_distance(node1, node2):
    """Return the Euclidean (L2) distance between two nodes' features.

    :param node1: object with a ``feature`` sequence of numbers.
    :param node2: object with a ``feature`` sequence of the same length.
    :returns: the distance as a numpy floating-point scalar.
    :raises AssertionError: if the feature vectors differ in length.
    """
    assert len(node1.feature) == len(node2.feature)
    diff = numpy.subtract(node1.feature, node2.feature)
    return numpy.sqrt(numpy.sum(numpy.square(diff)))

class KDTree(object):
    """A k-d tree built over KDNode objects, with k-nearest-neighbour search.

    The nodes passed in are linked together in place: building the tree sets
    each node's ``ki``, ``is_leaf``, ``kd_left`` and ``kd_right`` attributes.
    """

    # n is the number of feature dimensions
    def __init__(self, nodes, n):
        self.root = self.build_kdtree(nodes, n)
        self.n = n

    def build_kdtree(self, nodes, n):
        """Recursively build a subtree from ``nodes`` and return its root.

        The split dimension is the one (among the first ``n``) with the
        largest variance; ties and all-zero variance fall back to the lowest
        such index. After sorting along that dimension, the node at index
        ``len // 2`` (upper median for even counts) becomes the root.

        :param nodes: sequence of KDNode objects (may be empty).
        :param n: number of feature dimensions to consider.
        :returns: the root KDNode, or None when ``nodes`` is empty.
        """
        if len(nodes) == 0:
            return None
        max_var = 0
        index = 0
        for i in range(n):
            # a real list: numpy.var cannot consume a Python 3 map iterator
            var = numpy.var([node.feature[i] for node in nodes])
            if var > max_var:
                max_var = var
                index = i
        sorted_nodes = sorted(nodes, key=lambda node: node.feature[index])
        # floor division: '/' gives a float (invalid index) on Python 3
        mid = len(sorted_nodes) // 2
        root = sorted_nodes[mid]
        left_nodes = sorted_nodes[:mid]
        right_nodes = sorted_nodes[mid + 1:]

        root.ki = index
        # assign unconditionally so a node reused from an earlier tree does
        # not keep a stale is_leaf=True
        root.is_leaf = not left_nodes and not right_nodes
        root.kd_left = self.build_kdtree(left_nodes, n)
        root.kd_right = self.build_kdtree(right_nodes, n)
        return root

    def traverse_kdtree(self, order='in'):
        """Print the names of all nodes in the given order ('in'/'pre'/'post').

        An empty tree prints ``[]``.
        """
        seq = []
        if self.root is not None:
            self.root.traverse(seq, order)
        print([node.name for node in seq])

    def kdtree_bbf_knn(self, target, k):
        """Find the ``k`` tree nodes nearest to ``target``.

        Descends toward ``target``, keeping candidates in a bounded heap, and
        backtracks into the far side of a split only while that splitting
        plane is closer than the current k-th best distance. Subtrees are
        revisited in LIFO (depth-first) order.

        :param target: KDNode whose ``feature`` has ``self.n`` entries; tree
            nodes with the same ``name`` as ``target`` are skipped.
        :param k: number of neighbours wanted.
        :returns: list of NodeDistance sorted by ascending distance (fewer
            than ``k`` entries if the tree is smaller); ``[]`` when
            ``k <= 0``; None when ``target`` has the wrong dimensionality.
        """
        if len(target.feature) != self.n:
            return None
        if k <= 0:
            # k == 0 used to hit knn[0] on an empty heap (IndexError)
            return []
        # heap of NodeDistance; its reversed ordering makes knn[0] the
        # farthest of the current candidates
        knn = []
        # plain list used as a LIFO stack of (plane_distance, subtree)
        pending = [(0, self.root)]
        while pending:
            plane_distance, expl = pending.pop()
            # knn may have tightened since this subtree was pushed; every
            # point in it is at least plane_distance away, so it cannot win
            if len(knn) == k and plane_distance >= knn[0].distance:
                continue
            while expl is not None:
                ki = expl.ki
                kv = expl.feature[ki]

                if expl.name != target.name:  # ignore target node itself
                    distance = euclidean_distance(expl, target)
                    nd = NodeDistance(expl, distance)
                    if len(knn) < k:
                        heapq.heappush(knn, nd)
                    elif distance < knn[0].distance:
                        heapq.heapreplace(knn, nd)

                # descend to the side containing target; remember the other
                if target.feature[ki] <= kv:
                    unexpl, expl = expl.kd_right, expl.kd_left
                else:
                    unexpl, expl = expl.kd_left, expl.kd_right

                # the far side is worth a visit only if its splitting plane
                # is closer than the current k-th best candidate
                if unexpl is not None:
                    plane = abs(kv - target.feature[ki])
                    if len(knn) < k or plane < knn[0].distance:
                        pending.append((plane, unexpl))
        # heap pops come out farthest-first; reverse into ascending order
        ret = [heapq.heappop(knn) for _ in range(len(knn))]
        ret.reverse()
        return ret

if __name__ == '__main__':
    # Sample 2-D points; fx is the query point. It is also inserted into the
    # tree, and kdtree_bbf_knn skips it because the names match.
    f1 = [7, 2]
    f2 = [5, 4]
    f3 = [9, 6]
    f4 = [2, 3]
    f5 = [4, 7]
    f6 = [8, 1]
    fx = [2, 4.5]
    n1 = KDNode('f1', f1)
    n2 = KDNode('f2', f2)
    n3 = KDNode('f3', f3)
    n4 = KDNode('f4', f4)
    n5 = KDNode('f5', f5)
    n6 = KDNode('f6', f6)
    nx = KDNode('fx', fx)

    # Sanity-check NodeDistance's reversed ordering: a smaller distance
    # compares as greater.
    n1_distance = NodeDistance(n4, 1.5)
    n2_distance = NodeDistance(n5, 3.2)
    n3_distance = NodeDistance(n2, 3.04)
    assert n1_distance > n2_distance
    assert n1_distance > n3_distance
    assert n2_distance < n3_distance

    tree = KDTree([n1, n2, n3, n4, n5, n6, nx], 2)
    tree.traverse_kdtree('in')
    knn = tree.kdtree_bbf_knn(nx, 3)
    # float() keeps the output free of numpy scalar reprs
    print([(nd.kd_node.name, float(nd.distance)) for nd in knn])

 

KD Tree算法