Kd-tree and Nearest neighbor (NN) search (2D case)

A Kd-tree, or K-dimensional tree, is a generalization of a binary search tree that stores points in a k-dimensional space. In computer science it is often used for organizing points for efficient range and nearest neighbor (NN) searches – very common operations in computer vision, computational geometry, data mining, machine learning, and DNA sequencing. In this post we will deal with point sets in the two-dimensional Cartesian plane, so all of our Kd-trees will be two-dimensional.


Each level of a Kd-tree splits all children along a specific dimension by a hyperplane perpendicular to the corresponding axis. At the root of the tree all children are split based on the first dimension: points whose first coordinate is less than the root's go into the left sub-tree, and points whose first coordinate is greater go into the right sub-tree. Each level down in the tree splits on the next dimension, returning to the first dimension once all others have been considered. An efficient way to build a Kd-tree is to use a partitioning method, like the one used by Quicksort, to place the median point at the root, with everything that has a smaller coordinate along the splitting axis to the left and everything larger to the right. The procedure is then repeated on both the left and right sub-trees until the sub-trees to be partitioned contain only one element. More information about Kd-trees and references to other information sources can be found here.
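The construction code later in this post simply re-sorts the sub-list at each level, which is easy to follow; as a rough illustration of the median-by-partitioning idea mentioned above, a Quickselect-style median picker might look like the sketch below (an illustration only, not used by the rest of the code):

import random

def quickselect_median(points, axis):
    """ Pick the point with the median coordinate along `axis` using
        Quickselect-style partitioning (expected O(n), no full sort) """
    k = len(points) // 2                   # index the median would have after sorting
    candidates = list(points)
    while True:
        pivot = random.choice(candidates)
        smaller = [p for p in candidates if p[axis] < pivot[axis]]
        equal   = [p for p in candidates if p[axis] == pivot[axis]]
        if k < len(smaller):
            candidates = smaller           # the median lies among the smaller points
        elif k < len(smaller) + len(equal):
            return pivot                   # the pivot itself is the median
        else:
            k -= len(smaller) + len(equal)
            candidates = [p for p in candidates if p[axis] > pivot[axis]]

For example, quickselect_median([(2,3),(5,4),(9,6),(4,7)], axis=0) returns (5,4).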

We are especially interested in Kd-trees because of their efficiency. Building a Kd-tree (with the number of dimensions k fixed and the dataset size n variable) takes O(n \log n) time and O(kn) space, a NN search is close to O(\log n), and a search for the m nearest neighbors is close to O(m \log n).

For a quick recap of computational complexity, below is a Big-O complexity chart showing the number of operations (y-axis) required to obtain a result as a function of the number of elements (x-axis). The following complexity classes are shown: O(1) – constant time, O(\log n) – logarithmic time, O(n) – linear time, O(n \log n) – linearithmic time, O(n^2) – quadratic time, O(n^3) – cubic time, O(2^n) – exponential time, and O(n!) – factorial time. O(n!) is the worst: it already requires 720 operations for just 6 elements, while O(1) is the best, requiring a constant number of operations regardless of the number of elements.
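For instance, for n = 10^6 elements we get \log_2 n \approx 20, n \log_2 n \approx 2 \times 10^7, and n^2 = 10^{12} operations, so the gap between the complexity classes grows very quickly with n.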

To construct a Kd-tree we will use the Python code available on this page; it is quite simple and uses a median-finding sort. So-called named tuples are used there to keep the tree structure in memory. A named tuple is just a normal tuple with names (keys) for its elements; named tuples assign meaning to each position in a tuple and allow for more readable, self-documenting code. Here is a simple example:

from collections import namedtuple

Person = namedtuple('Person', 'name year gender')
a = Person(name='Alexey', year=1985, gender='male')

print('a =', a)
print('a[0] =', a[0])
print('a.name =', a.name)

As you can see, in named tuples you can access each position by its index as well as its name. More information about named tuples can be found here. Here is the code for a Kd-tree construction:

from collections import namedtuple
from operator import itemgetter
from pprint import pformat

class Node(namedtuple('Node', 'location left_child right_child')):

    def __repr__(self):
        return pformat(tuple(self))

def kdtree(point_list, depth=0):
    """ build K-D tree
    :param point_list list of input points
    :param depth      current tree's depth
    :return tree node
    """

    # assumes all points have the same dimension
    try:
        k = len(point_list[0])
    except IndexError:
        return None

    # Select axis based on depth so that axis cycles through
    # all valid values
    axis = depth % k

    # Sort point list and choose median as pivot element
    point_list.sort(key=itemgetter(axis))
    median = len(point_list) // 2         # choose median

    # Create node and construct subtrees
    return Node(
        location=point_list[median],
        left_child=kdtree(point_list[:median], depth + 1),
        right_child=kdtree(point_list[median + 1:], depth + 1)
    )
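As a quick sanity check, the tree can be built from a handful of points and printed; the Node.__repr__ above pretty-prints the nested tuples (the six sample points below are arbitrary):

sample_points = [(7, 2), (5, 4), (9, 6), (4, 7), (8, 1), (2, 3)]
print(kdtree(sample_points))
# the root is (7, 2): the median along the x-axis at depth 0;
# its left sub-tree holds the points with smaller x, the right one those with larger x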

Now we need to generate an input point-set for building a tree. As point coordinates we will use either integer or float values (a one-line difference in the code). Here is a function which generates a list of random points in the two-dimensional Cartesian plane:


import random
import numpy as np

def generate_point_list(n, min_val, max_val):
    """ generate a list of random points
    :param n        number of points
    :param min_val  minimal coordinate value
    :param max_val  maximal coordinate value
    :return list of n random 2D points
    """

    p = []

    for i in range(n):

        # coordinates as integer values
        p.append((random.randint(min_val,max_val),
                  random.randint(min_val,max_val)))

        # coordinates as float values
        #p.append((np.random.normal(random.randint(min_val,max_val), scale=0.5),
        #          np.random.normal(random.randint(min_val,max_val), scale=0.5)))

    return p

To build a two-dimensional Kd-tree the following parameters need to be specified: n – number of points in the input dataset, min_val – minimal coordinate value, max_val – maximal coordinate value:


n = 50        # number of points
min_val = 0   # minimal coordinate value
max_val = 20  # maximal coordinate value

point_list = generate_point_list(n, min_val, max_val)

# construct a K-D tree
kd_tree = kdtree(point_list)

And of course we want to see what our tree looks like, and for this we need visualization. Here is my code for visualizing two-dimensional Kd-trees:


import matplotlib.pyplot as plt

# line width for visualization of K-D tree
line_width = [4., 3.5, 3., 2.5, 2., 1.5, 1., .5, 0.3]

def plot_tree(tree, min_x, max_x, min_y, max_y, prev_node, branch, depth=0):
    """ plot K-D tree
    :param tree      input tree to be plotted
    :param min_x     minimal x coordinate of the drawing area
    :param max_x     maximal x coordinate of the drawing area
    :param min_y     minimal y coordinate of the drawing area
    :param max_y     maximal y coordinate of the drawing area
    :param prev_node parent node
    :param branch    True if left branch, False if right branch
    :param depth     tree's depth
    """

    cur_node = tree.location         # current tree's node
    left_branch = tree.left_child    # its left branch
    right_branch = tree.right_child  # its right branch

    # set line's width depending on tree's depth
    if depth > len(line_width)-1:
        ln_width = line_width[len(line_width)-1]
    else:
        ln_width = line_width[depth]

    k = len(cur_node)
    axis = depth % k

    # draw a vertical splitting line
    if axis == 0:

        if branch is not None and prev_node is not None:

            if branch:
                max_y = prev_node[1]
            else:
                min_y = prev_node[1]

        plt.plot([cur_node[0],cur_node[0]], [min_y,max_y], linestyle='-', color='red', linewidth=ln_width)

    # draw a horizontal splitting line
    elif axis == 1:

        if branch is not None and prev_node is not None:

            if branch:
                max_x = prev_node[0]
            else:
                min_x = prev_node[0]

        plt.plot([min_x,max_x], [cur_node[1],cur_node[1]], linestyle='-', color='blue', linewidth=ln_width)

    # draw the current node
    plt.plot(cur_node[0], cur_node[1], 'ko')

    # draw left and right branches of the current node
    if left_branch is not None:
        plot_tree(left_branch, min_x, max_x, min_y, max_y, cur_node, True, depth+1)

    if right_branch is not None:
        plot_tree(right_branch, min_x, max_x, min_y, max_y, cur_node, False, depth+1)

delta = 2  # extension of the drawing range

plt.figure("K-d Tree", figsize=(10., 10.))
plt.axis([min_val-delta, max_val+delta, min_val-delta, max_val+delta])

plt.grid(True, which='major', color='0.75', linestyle='--')
plt.xticks([i for i in range(min_val-delta, max_val+delta, 1)])
plt.yticks([i for i in range(min_val-delta, max_val+delta, 1)])

# draw the tree
plot_tree(kd_tree, min_val-delta, max_val+delta, min_val-delta, max_val+delta, None, None)

plt.title('K-D Tree')
plt.show()
plt.close()

The constructed tree might look as shown below. Red lines show vertical splitting hyperplanes, while blue lines show horizontal ones. The line thickness corresponds to the tree depth: the deeper the node, the thinner the line.

Let’s now build a tree with more nodes:


n = 300       # number of points
min_val = 0   # minimal coordinate value
max_val = 20  # maximal coordinate value

Let’s also see what both trees might look like with float coordinate values. For this, a small change needs to be made in the generate_point_list() function (see the commented lines above).

Finally, we have our Kd-tree and are ready to use it. Here we will talk about the most common operation with Kd-trees – Nearest neighbor (NN) search.

Usually the task is formulated as follows. We are given n points in some space S. We have to answer so-called queries, which take the dataset S and some point X (also called the “target point”) as their parameters (X does not have to belong to S). Typical queries are “find the m nearest points to X” or “find all points in S at distance R from X or closer”.

Depending on the problem, we may have:

  • Different numbers of dimensions – from one to thousands.
  • Different metric types (Euclidean, 1-norm, etc.). Keep in mind that the elements stored in a Kd-tree are not necessarily points in Cartesian space.
  • Different dataset sizes.

Here we will focus on the first query with m = 1, i.e. we will look for the nearest neighbor of a given target point. In this problem the dataset S is considered fixed: X may vary from query to query, but S remains unchanged. This makes it possible to preprocess the dataset and build a data structure (in our case a Kd-tree) that accelerates the search. All approaches promising performance better than O(n) rely on some kind of preprocessing. Searching for the nearest neighbor in a Kd-tree proceeds as follows:

  1. Starting at the root node, the algorithm moves down the tree recursively (it goes left or right depending on whether the target point is less than or greater than the current node in the splitting dimension).
  2. While traversing the tree, the algorithm saves the node with the shortest distance to the target point as the “current best”.
  3. Once the algorithm reaches a leaf node, it unwinds the recursion, performing the following steps at each node:
    1. If the current node is closer than the current best, it becomes the current best.
    2. The algorithm checks whether there could be any points on the other side of the splitting plane that are closer to the target point than the current best. Conceptually, this is done by intersecting the splitting hyperplane with a hypersphere around the target point whose radius equals the current nearest distance. Since the hyperplanes are all axis-aligned, this is implemented as a simple comparison: is the difference between the splitting coordinate of the target point and the current node smaller than the distance from the target point to the current best? For this we will use so-called hyperrectangles: every splitting hyperplane divides the current hyperrectangle into two pieces – the “nearer hyperrectangle”, which contains the target point, and the “further hyperrectangle” on the other side of the hyperplane.
      1. If the hypersphere crosses the plane, there could be nearer points on the other side of the plane, so the algorithm must move down the other branch of the tree from the current node, looking for closer points and following the same recursive process as the entire search.
      2. If the hypersphere doesn’t intersect the splitting plane, the algorithm continues walking up the tree, and the entire branch on the other side of that node is eliminated.
  4. The search is complete when the algorithm has finished this procedure for the root node.

The Python code performing all these steps is presented below. Note that the nearest node and its distance to the target point are stored in global variables, and that the algorithm compares squared distances to avoid computing square roots.


nearest_nn = None           # nearest neighbor (NN)
distance_nn = float('inf')  # distance from NN to target

def nearest_neighbor_search(tree, target_point, hr, distance, nearest=None, depth=0):
    """ Find the nearest neighbor for the given target point (close to O(log(n)) on average)
    :param tree         K-D tree
    :param target_point given point for the NN search
    :param hr           hyperrectangle of the current node
    :param distance     minimal squared distance found so far
    :param nearest      nearest point found so far
    :param depth        tree's depth
    """

    global nearest_nn
    global distance_nn

    if tree is None:
        return

    k = len(target_point)

    cur_node = tree.location         # current tree's node
    left_branch = tree.left_child    # its left branch
    right_branch = tree.right_child  # its right branch

    nearer_kd = further_kd = None
    nearer_hr = further_hr = None
    left_hr = right_hr = None

    # Select axis based on depth so that axis cycles through all valid values
    axis = depth % k

    # split the current hyperrectangle along the splitting axis
    if axis == 0:
        left_hr = [hr[0], (cur_node[0], hr[1][1])]
        right_hr = [(cur_node[0],hr[0][1]), hr[1]]

    if axis == 1:
        left_hr = [(hr[0][0], cur_node[1]), hr[1]]
        right_hr = [hr[0], (hr[1][0], cur_node[1])]

    # check which sub-hyperrectangle (and branch) the target point belongs to
    if target_point[axis] <= cur_node[axis]:
        nearer_kd = left_branch
        further_kd = right_branch
        nearer_hr = left_hr
        further_hr = right_hr

    if target_point[axis] > cur_node[axis]:
        nearer_kd = right_branch
        further_kd = left_branch
        nearer_hr = right_hr
        further_hr = left_hr

    # check whether the current node is closer
    dist = (cur_node[0] - target_point[0])**2 + (cur_node[1] - target_point[1])**2

    if dist < distance:
        nearest = cur_node
        distance = dist

    # go deeper in the tree
    nearest_neighbor_search(nearer_kd, target_point, nearer_hr, distance, nearest, depth+1)

    # after recursing into the nearer sub-tree, update the global best
    # if this branch has found a closer point
    if distance < distance_nn:
        nearest_nn = nearest
        distance_nn = distance

    # (px, py) is the closest point of the further hyperrectangle to the target point;
    # a nearer neighbor could only lie in further_kd if this point is close enough
    px = compute_closest_coordinate(target_point[0], further_hr[0][0], further_hr[1][0])
    py = compute_closest_coordinate(target_point[1], further_hr[1][1], further_hr[0][1])

    # check whether it is closer than the current best, i.e. whether the
    # hypersphere around the target point crosses the splitting hyperplane
    dist = (px - target_point[0])**2 + (py - target_point[1])**2

    # explore the further sub-tree / hyperrectangle if necessary
    if dist < distance_nn:
        nearest_neighbor_search(further_kd, target_point, further_hr, distance, nearest, depth+1)

The closest coordinate of the neighboring hyperrectangle is computed as follows:


def compute_closest_coordinate(value, range_min, range_max):
    """ Compute the closest coordinate of the neighboring hyperrectangle
    :param value     coordinate value (x or y) of the target point
    :param range_min minimal coordinate (x or y) of the neighboring hyperrectangle
    :param range_max maximal coordinate (x or y) of the neighboring hyperrectangle
    :return the closest x or y coordinate
    """

    v = None

    if range_min < value < range_max:
        v = value

    elif value <= range_min:
        v = range_min

    elif value >= range_max:
        v = range_max

    return v
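In effect the function simply clamps the target's coordinate to the extent of the neighboring hyperrectangle along that axis, so that (px, py) is the point of that region closest to the target. A quick check (with made-up values):

print(compute_closest_coordinate(5.0, 2.0, 8.0))  # 5.0 - inside the range
print(compute_closest_coordinate(1.0, 2.0, 8.0))  # 2.0 - clamped to the lower bound
print(compute_closest_coordinate(9.5, 2.0, 8.0))  # 8.0 - clamped to the upper bound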

The next piece of code generates a random target point, performs the NN search, and visualizes the result:

import math

# generate a random point on the grid
point = (np.random.normal(random.randint(min_val,max_val), scale=0.5), np.random.normal(random.randint(min_val,max_val), scale=0.5))

delta = 2  # extension of the drawing range

hr = [(min_val-delta, max_val+delta), (max_val+delta, min_val-delta)]  # initial hyperrectangle: (top-left corner, bottom-right corner) of the drawing area
max_dist = float('inf')

# find the nearest neighbor
nearest_neighbor_search(kd_tree, point, hr, max_dist)

# draw the given point
plt.plot(point[0], point[1], marker='o', color='#ff007f')
circle = plt.Circle((point[0], point[1]), 0.3, facecolor='#ff007f', edgecolor='#ff007f', alpha=0.5)
plt.gca().add_patch(circle)

# draw the hypersphere around the target point
circle = plt.Circle((point[0], point[1]), math.sqrt(distance_nn), facecolor='#ffd83d', edgecolor='#ffd83d', alpha=0.5)
plt.gca().add_patch(circle)

# draw the found nearest neighbor
plt.plot(nearest_nn[0], nearest_nn[1], 'go')
circle = plt.Circle((nearest_nn[0], nearest_nn[1]), 0.3, facecolor='#33cc00', edgecolor='#33cc00', alpha=0.5)
plt.gca().add_patch(circle)

Results of the NN search for Kd-trees with n = 50 nodes and integer or float coordinate values are shown below. The pink point is the target point, the green point is its nearest neighbor, and the yellow circle shows the hypersphere (a circle in 2D) around the target point.

For n = 300:

Finding the NN is an O(\log n) operation on average for randomly distributed points. The biggest advantage of the Kd-tree NN search is that it eliminates many points from consideration, so only a few branches of the tree need to be visited. However, in high-dimensional spaces the curse of dimensionality causes the algorithm to visit many more branches than in lower-dimensional spaces; in particular, when the number of points is only slightly higher than the number of dimensions, the algorithm is only slightly better than a linear search of all points.
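A simple way to check the result (and to appreciate the speed-up) is to compare the tree-based answer against a brute-force O(n) scan – a minimal sketch using the point_list and point variables defined above (ties may yield a different but equally distant point):

def nearest_neighbor_brute_force(points, target_point):
    """ O(n) reference: scan every point and keep the closest (squared distances) """
    best, best_dist = None, float('inf')
    for p in points:
        d = (p[0] - target_point[0])**2 + (p[1] - target_point[1])**2
        if d < best_dist:
            best, best_dist = p, d
    return best, best_dist

# the brute-force answer should agree with the Kd-tree search
print(nearest_neighbor_brute_force(point_list, point))
print(nearest_nn, distance_nn)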

Best wishes and feel free to use / improve the code,
Alexey


