A **Kd-tree**, or **K-dimensional tree**, is a generalization of a binary search tree that stores points in a *k*-dimensional space. In computer science it is often used to organize a set of points in a space with *k* dimensions. Kd-trees are very useful for **range** and **nearest neighbor (NN) searches**, which are very common operations in computer vision, computational geometry, data mining, machine learning, and DNA sequencing. In this post we will deal with point sets in the two-dimensional Cartesian space, so all of our Kd-trees will be two-dimensional.

Each level of a Kd-tree splits all children along a specific dimension by a hyperplane that is perpendicular to the corresponding axis. At the root of the tree all children are split based on the first dimension: a point whose first coordinate is less than that of the root goes into the left sub-tree, and a point whose first coordinate is greater goes into the right sub-tree. Each level down in the tree divides on the next dimension, returning to the first dimension once all others have been considered. An efficient way to build a Kd-tree is to use a partitioning method like the one Quick Sort uses to place the median point at the root, with everything that has a smaller value in the splitting dimension to the left and everything larger to the right. The procedure is then repeated on both the left and right sub-trees until the last trees to be partitioned consist of only one element. More information about Kd-trees and references to other information sources can be found here.

We are especially interested in Kd-trees due to their efficiency. Building a Kd-tree (considering the number of dimensions $K$ fixed and the dataset size $N$ variable) has $O(N \log N)$ time complexity and $O(K \cdot N)$ space complexity, the NN search is close to $O(\log N)$, and the search of $M$ nearest neighbors is close to $O(M \log N)$.

As a recap of computational complexity, below is a Big-O complexity chart showing the number of operations (y-axis) required to obtain a result as a function of the number of elements (x-axis). The following computational complexity functions are presented: $O(1)$ – constant time, $O(\log n)$ – logarithmic time, $O(n)$ – linear time, $O(n \log n)$ – linearithmic time, $O(n^2)$ – quadratic time, $O(n^3)$ – cubic time, $O(2^n)$ – exponential time, and $O(n!)$ – factorial time. $O(n!)$ is the worst complexity, requiring a huge number of operations for just a few elements, while $O(1)$ is the best complexity, which requires a constant number of operations for any number of elements.
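To make the growth rates concrete, here is a small sketch that evaluates each of these functions for a given input size. The function name `operation_counts` and the choice of base-2 logarithms are assumptions made for illustration only.

```python
import math


def operation_counts(n):
    """Return the value of each common complexity function for input size n."""
    return {
        "O(1)": 1,
        "O(log n)": math.log2(n),
        "O(n)": n,
        "O(n log n)": n * math.log2(n),
        "O(n^2)": n ** 2,
        "O(n^3)": n ** 3,
        "O(2^n)": 2 ** n,
        "O(n!)": math.factorial(n),
    }


# print the operation counts for 16 elements
for name, count in operation_counts(16).items():
    print(f"{name:>10}: {count:,.0f}")
```

Already for 16 elements the factorial count is many orders of magnitude above all the others, while the logarithmic count stays at 4.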

To construct a Kd-tree we will use the Python code available on this page. It is quite simple and uses a median-finding sort. So-called named tuples are used there to keep the tree structure in memory. A named tuple is just a normal tuple with names (keys) for its elements. Named tuples assign meaning to each position in a tuple and allow for more readable, self-documenting code. Here is a simple example:

```python
from collections import namedtuple

person = namedtuple('Person', 'name year gender')
a = person(name='Alexey', year=1985, gender='male')

print('a = ', a)
print('a[0] = ', a[0])
print('a.name = ', a.name)
```

As you can see, in named tuples you can access each position by its index as well as its name. More information about named tuples can be found here. Here is the code for a Kd-tree construction:

```python
from collections import namedtuple
from operator import itemgetter
from pprint import pformat


class Node(namedtuple('Node', 'location left_child right_child')):
    def __repr__(self):
        return pformat(tuple(self))


def kdtree(point_list, depth=0):
    """ build K-D tree
    :param point_list list of input points
    :param depth      current tree's depth
    :return tree node
    """
    # assumes all points have the same dimension
    try:
        k = len(point_list[0])
    except IndexError:
        return None

    # Select axis based on depth so that axis cycles through
    # all valid values
    axis = depth % k

    # Sort point list and choose median as pivot element
    point_list.sort(key=itemgetter(axis))
    median = len(point_list) // 2  # choose median

    # Create node and construct subtrees
    return Node(
        location=point_list[median],
        left_child=kdtree(point_list[:median], depth + 1),
        right_child=kdtree(point_list[median + 1:], depth + 1)
    )
```
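As a small worked example, the sketch below builds a tree for six hand-picked points and prints the first two levels. It is a compact variant of the construction above, not the original code: the names `KdNode` and `build_kdtree` are hypothetical, and it uses `sorted()` on a copy so that the caller's list keeps its order.

```python
from collections import namedtuple

# Hypothetical names for this sketch; the code above uses Node and kdtree().
KdNode = namedtuple("KdNode", "location left_child right_child")


def build_kdtree(points, depth=0):
    """Build a Kd-tree without reordering the caller's list."""
    if not points:
        return None
    axis = depth % len(points[0])                    # cycle through the dimensions
    ordered = sorted(points, key=lambda p: p[axis])  # sorted copy, input untouched
    median = len(ordered) // 2
    return KdNode(
        location=ordered[median],
        left_child=build_kdtree(ordered[:median], depth + 1),
        right_child=build_kdtree(ordered[median + 1:], depth + 1),
    )


points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = build_kdtree(points)
print(root.location)              # (7, 2): median by x of all points
print(root.left_child.location)   # (5, 4): median by y of the left half
print(root.right_child.location)  # (9, 6): median by y of the right half
```

The root splits on x, its children split on y, and the next level would split on x again.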

Now we need to generate an input point-set for building a tree. As point coordinates we will use both integer and float values (one code line difference). Here is a function which generates a list of random points in the two-dimensional Cartesian space:

```python
import random
import numpy as np


def generate_point_list(n, min_val, max_val):
    """ generate a list of random points
    :param n       number of points
    :param min_val minimal value
    :param max_val maximal value
    :return list of (x, y) tuples
    """
    p = []

    for i in range(n):
        # coordinates as integer values
        p.append((random.randint(min_val, max_val),
                  random.randint(min_val, max_val)))

        # coordinates as float values
        #p.append((np.random.normal(random.randint(min_val, max_val), scale=0.5),
        #          np.random.normal(random.randint(min_val, max_val), scale=0.5)))

    return p
```

To build a two-dimensional Kd-tree the following parameters need to be specified: *n* – number of points in the input dataset, *min_val* – minimal coordinate value, *max_val* – maximal coordinate value:

```python
n = 50        # number of points
min_val = 0   # minimal coordinate value
max_val = 20  # maximal coordinate value

point_list = generate_point_list(n, min_val, max_val)

# construct a K-D tree
kd_tree = kdtree(point_list)
```

And of course we want to see what our tree looks like, and for this we need a visualization. Here is my code for visualizing two-dimensional Kd-trees:

```python
import matplotlib.pyplot as plt

# line width for visualization of K-D tree
line_width = [4., 3.5, 3., 2.5, 2., 1.5, 1., .5, 0.3]


def plot_tree(tree, min_x, max_x, min_y, max_y, prev_node, branch, depth=0):
    """ plot K-D tree
    :param tree      input tree to be plotted
    :param min_x     minimal x of the region the splitting line is drawn in
    :param max_x     maximal x of that region
    :param min_y     minimal y of that region
    :param max_y     maximal y of that region
    :param prev_node parent's node
    :param branch    True if left, False if right
    :param depth     tree's depth
    """
    cur_node = tree.location         # current tree's node
    left_branch = tree.left_child    # its left branch
    right_branch = tree.right_child  # its right branch

    # set line's width depending on tree's depth
    ln_width = line_width[min(depth, len(line_width) - 1)]

    k = len(cur_node)
    axis = depth % k

    # draw a vertical splitting line
    if axis == 0:
        if branch is not None and prev_node is not None:
            if branch:
                max_y = prev_node[1]
            else:
                min_y = prev_node[1]
        plt.plot([cur_node[0], cur_node[0]], [min_y, max_y],
                 linestyle='-', color='red', linewidth=ln_width)

    # draw a horizontal splitting line
    elif axis == 1:
        if branch is not None and prev_node is not None:
            if branch:
                max_x = prev_node[0]
            else:
                min_x = prev_node[0]
        plt.plot([min_x, max_x], [cur_node[1], cur_node[1]],
                 linestyle='-', color='blue', linewidth=ln_width)

    # draw the current node
    plt.plot(cur_node[0], cur_node[1], 'ko')

    # draw left and right branches of the current node
    if left_branch is not None:
        plot_tree(left_branch, min_x, max_x, min_y, max_y, cur_node, True, depth + 1)
    if right_branch is not None:
        plot_tree(right_branch, min_x, max_x, min_y, max_y, cur_node, False, depth + 1)


delta = 2  # extension of the drawing range

plt.figure("K-d Tree", figsize=(10., 10.))
plt.axis([min_val - delta, max_val + delta, min_val - delta, max_val + delta])

plt.grid(True, which='major', color='0.75', linestyle='--')
plt.xticks(list(range(min_val - delta, max_val + delta, 1)))
plt.yticks(list(range(min_val - delta, max_val + delta, 1)))

# draw the tree
plot_tree(kd_tree, min_val - delta, max_val + delta, min_val - delta,
          max_val + delta, None, None)

plt.title('K-D Tree')
plt.show()
plt.close()
```

The constructed tree might look as shown below. Red lines show vertical hyperplanes, while blue lines show horizontal hyperplanes. Line thickness corresponds to tree’s depth (the thinner the deeper).

Let’s now build a tree with more nodes:

```python
n = 300       # number of points
min_val = 0   # minimal coordinate value
max_val = 20  # maximal coordinate value
```

Let’s also see how both trees might look in the case of float coordinate values. For this, a small change in the *generate_point_list()* function is needed (see above).

Finally, we have our Kd-tree and are ready to use it. Here we will talk about the most common operation with Kd-trees – Nearest neighbor (NN) search.

Usually this task is formulated as follows. $N$ points in some space are given; they form the dataset $X$. We have to work with so-called *queries*, which have the dataset $X$ and some point $x$ (also called the *“target point”*) as their parameters ($x$ does not have to belong to $X$). Typical queries are *“find the $M$ nearest points of $x$”* or *“find all points in $X$ at a given distance $R$ from $x$ or closer”*.
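Both query types can be stated in a few lines as a plain linear scan over the dataset, without any tree. The sketch below is only meant to pin down what the queries return; the function names `m_nearest` and `within_radius` and the sample points are made up for illustration.

```python
import heapq


def squared_distance(p, q):
    """Squared Euclidean distance between two points of equal dimension."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def m_nearest(dataset, target, m):
    """Return the m points of dataset closest to target (linear scan)."""
    return heapq.nsmallest(m, dataset, key=lambda p: squared_distance(p, target))


def within_radius(dataset, target, radius):
    """Return all points of dataset at distance radius from target or closer."""
    return [p for p in dataset if squared_distance(p, target) <= radius ** 2]


X = [(1, 1), (2, 5), (6, 2), (7, 7), (3, 3)]  # dataset
x = (3, 2)                                    # target point, not part of X

print(m_nearest(X, x, 2))        # [(3, 3), (1, 1)]
print(within_radius(X, x, 2.5))  # [(1, 1), (3, 3)]
```

Every query here touches all points of the dataset, which is exactly the cost a Kd-tree is meant to avoid.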

Depending on the problem, we may have:

- Different number of dimensions – from one to thousands.
- Different metric type (Euclidean, 1-norm, etc.). Do not forget – elements in the Kd-tree are not necessarily points in the Cartesian space.
- Different dataset sizes.
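The choice of metric matters because it can change which point counts as the nearest one. The following sketch compares the Euclidean distance with the 1-norm on two hand-picked points; the helper names are hypothetical and the search is a plain linear scan.

```python
import math


def euclidean(p, q):
    """Euclidean (2-norm) distance."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))


def manhattan(p, q):
    """1-norm distance: sum of absolute coordinate differences."""
    return sum(abs(a - b) for a, b in zip(p, q))


def nearest(dataset, target, metric):
    """Return the point of dataset closest to target under the given metric."""
    return min(dataset, key=lambda p: metric(p, target))


dataset = [(3, 3), (0, 5)]
target = (0, 0)

print(nearest(dataset, target, euclidean))  # (3, 3): about 4.24 versus 5
print(nearest(dataset, target, manhattan))  # (0, 5): 6 versus 5
```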

Here we will focus on the first query, i.e. we will look for the nearest neighbor of a given target point. In this problem the dataset $X$ is considered fixed: $x$ may vary from request to request, but $X$ remains unchanged. This makes it possible to preprocess the dataset and build a data structure (in our case a Kd-tree) which accelerates the search procedure. All approaches promising performance better than $O(N)$ rely on some kind of preprocessing. Searching for the nearest neighbor in a Kd-tree proceeds as follows:

- Starting with the root node, the algorithm moves down the tree recursively (it goes left or right depending on whether the target point is less than or greater than the current node in the split dimension).
- While traversing the tree, the algorithm saves the node with the shortest distance to the target point as the “current best”.
- Once the algorithm reaches a leaf node, it unwinds the recursion of the tree, performing the following steps at each node:
  - If the current node is closer than the current best, then it becomes the current best.
  - The algorithm checks whether there could be any points on the other side of the splitting plane that are closer to the target point than the current best. This is done by intersecting the splitting hyperplane with a hypersphere around the target point. The sphere has a radius equal to the current nearest distance. Since the hyperplanes are all axis-aligned, this can be implemented as a simple comparison to see whether the difference between the splitting coordinate of the target point and the current node is less than the distance from the target point to the current best. In our implementation this check is done with so-called *hyperrectangles*: every hyperplane divides the current hyperrectangle into two pieces, the *“near hyperrectangle”*, which the target point belongs to, and the *“further hyperrectangle”* on the other side of the hyperplane.
    - If the hypersphere crosses the plane, there could be nearer points on the other side of the plane. It means the algorithm must move down the other branch of the tree from the current node looking for closer points, following the same recursive process as the entire search.
    - If the hypersphere doesn’t intersect the splitting plane, then the algorithm continues walking up the tree, and the entire branch on the other side of that node is eliminated.

- The search is complete when the algorithm finishes this procedure for the root node.

The Python code performing all these steps is presented below. Note that the nearest node and its distance to the target point are stored in global variables, which therefore have to be reset before every new query. Also, the algorithm uses squared distances for comparison to avoid computing square roots.

```python
nearest_nn = None            # nearest neighbor (NN)
distance_nn = float('inf')   # squared distance from NN to target


def nearest_neighbor_search(tree, target_point, hr, distance, nearest=None, depth=0):
    """ Find the nearest neighbor for the given point (claims O(log(n)) complexity)
    :param tree         K-D tree
    :param target_point given point for the NN search
    :param hr           current hyperrectangle as [(min_x, max_y), (max_x, min_y)]
    :param distance     minimal squared distance found so far
    :param nearest      nearest point found so far
    :param depth        tree's depth
    """
    global nearest_nn
    global distance_nn

    if tree is None:
        return

    k = len(target_point)

    cur_node = tree.location         # current tree's node
    left_branch = tree.left_child    # its left branch
    right_branch = tree.right_child  # its right branch

    nearer_kd = further_kd = None
    nearer_hr = further_hr = None
    left_hr = right_hr = None

    # Select axis based on depth so that axis cycles through all valid values
    axis = depth % k

    # split the hyperrectangle depending on the axis
    if axis == 0:
        left_hr = [hr[0], (cur_node[0], hr[1][1])]
        right_hr = [(cur_node[0], hr[0][1]), hr[1]]

    if axis == 1:
        left_hr = [(hr[0][0], cur_node[1]), hr[1]]
        right_hr = [hr[0], (hr[1][0], cur_node[1])]

    # check which hyperrectangle the target point belongs to
    if target_point[axis] <= cur_node[axis]:
        nearer_kd = left_branch
        further_kd = right_branch
        nearer_hr = left_hr
        further_hr = right_hr
    else:
        nearer_kd = right_branch
        further_kd = left_branch
        nearer_hr = right_hr
        further_hr = left_hr

    # check whether the current node is closer
    dist = (cur_node[0] - target_point[0])**2 + (cur_node[1] - target_point[1])**2

    if dist < distance:
        nearest = cur_node
        distance = dist

    # go deeper in the tree
    nearest_neighbor_search(nearer_kd, target_point, nearer_hr, distance, nearest, depth + 1)

    # once we reached the leaf node we check whether there are closer points
    # inside the hypersphere
    if distance < distance_nn:
        nearest_nn = nearest
        distance_nn = distance

    # a nearer point (px, py) could only be in further_kd (further_hr) -> explore it
    px = compute_closest_coordinate(target_point[0], further_hr[0][0], further_hr[1][0])
    py = compute_closest_coordinate(target_point[1], further_hr[1][1], further_hr[0][1])

    # check whether it is closer than the current nearest neighbor
    # => whether the hypersphere crosses the splitting hyperplane
    dist = (px - target_point[0])**2 + (py - target_point[1])**2

    # explore the further kd-tree / hyperrectangle if necessary
    if dist < distance_nn:
        nearest_neighbor_search(further_kd, target_point, further_hr, distance, nearest, depth + 1)
```

The closest coordinate of the neighboring hyperrectangle is computed as:

```python
def compute_closest_coordinate(value, range_min, range_max):
    """ Compute the closest coordinate for the neighboring hyperrectangle
    :param value     coordinate value (x or y) of the target point
    :param range_min minimal coordinate (x or y) of the neighboring hyperrectangle
    :param range_max maximal coordinate (x or y) of the neighboring hyperrectangle
    :return x or y coordinate
    """
    v = None

    if range_min < value < range_max:
        v = value
    elif value <= range_min:
        v = range_min
    elif value >= range_max:
        v = range_max

    return v
```
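For comparison, the simpler pruning test mentioned in the step list (comparing the distance to the splitting plane with the current best distance) can be sketched without hyperrectangles and without global variables. This is an alternative sketch, not the code used in this post: the names `KdNode`, `build` and `nearest_neighbor` are hypothetical, and the function returns the result instead of storing it globally.

```python
from collections import namedtuple

# Hypothetical names for this sketch
KdNode = namedtuple("KdNode", "location left_child right_child")


def build(points, depth=0):
    """Build a Kd-tree by median splitting (sorted copy at every level)."""
    if not points:
        return None
    axis = depth % len(points[0])
    ordered = sorted(points, key=lambda p: p[axis])
    median = len(ordered) // 2
    return KdNode(ordered[median],
                  build(ordered[:median], depth + 1),
                  build(ordered[median + 1:], depth + 1))


def nearest_neighbor(node, target, depth=0, best=None):
    """Return (nearest point, squared distance) for the given target."""
    if node is None:
        return best

    point = node.location
    dist = sum((a - b) ** 2 for a, b in zip(point, target))
    if best is None or dist < best[1]:
        best = (point, dist)

    # signed distance from the target to the splitting plane of this node
    axis = depth % len(target)
    diff = target[axis] - point[axis]
    if diff <= 0:
        near, far = node.left_child, node.right_child
    else:
        near, far = node.right_child, node.left_child

    # search the side of the plane that contains the target first
    best = nearest_neighbor(near, target, depth + 1, best)

    # the other side can only hold a closer point if the plane itself
    # is closer to the target than the current best
    if diff ** 2 < best[1]:
        best = nearest_neighbor(far, target, depth + 1, best)

    return best


points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = build(points)
print(nearest_neighbor(root, (9, 2)))  # ((8, 1), 2)
```

Because the best candidate is passed through the recursion and returned, several queries can be run one after another without resetting any state.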

The next piece of code generates a random target point, performs the NN search, and visualizes the result:

```python
import math

# generate a random point on the grid
point = (np.random.normal(random.randint(min_val, max_val), scale=0.5),
         np.random.normal(random.randint(min_val, max_val), scale=0.5))

delta = 2  # extension of the drawing range

# initial hyperrectangle: [(min_x, max_y), (max_x, min_y)]
hr = [(min_val - delta, max_val + delta), (max_val + delta, min_val - delta)]
max_dist = float('inf')

# find the nearest neighbor
nearest_neighbor_search(kd_tree, point, hr, max_dist)

# the drawing calls below must run before plt.show() in the plotting code above

# draw the given point
plt.plot(point[0], point[1], marker='o', color='#ff007f')
circle = plt.Circle((point[0], point[1]), 0.3,
                    facecolor='#ff007f', edgecolor='#ff007f', alpha=0.5)
plt.gca().add_patch(circle)

# draw the hypersphere around the target point
circle = plt.Circle((point[0], point[1]), math.sqrt(distance_nn),
                    facecolor='#ffd83d', edgecolor='#ffd83d', alpha=0.5)
plt.gca().add_patch(circle)

# draw the found nearest neighbor
plt.plot(nearest_nn[0], nearest_nn[1], 'go')
circle = plt.Circle((nearest_nn[0], nearest_nn[1]), 0.3,
                    facecolor='#33cc00', edgecolor='#33cc00', alpha=0.5)
plt.gca().add_patch(circle)
```

Results of the NN search for Kd-trees with integer and float coordinate values are shown below. The red point is the target point, the green point is its NN, and the yellow circle shows the hypersphere (in 2D) around the target point.


Finding the NN is an $O(\log N)$ operation in the case of randomly distributed points. The biggest advantage of the NN search using Kd-trees is that it allows us to eliminate many points from consideration and focus only on some of the tree’s branches. However, in high-dimensional spaces, the curse of dimensionality causes the algorithm to visit many more branches than in lower-dimensional spaces. In particular, when the number of points is only slightly higher than the number of dimensions, the algorithm is only slightly better than a linear search of all points.

Best wishes and feel free to use / improve the code,

Alexey

This is a beautifully written-up article – thank you so much! Your explanations are particularly clear – there are many articles on KD trees on the internet but many are incomplete, difficult to follow, or cope with adding KD points one at a time rather than balancing the tree up-front – as you do – by passing them all in at the start (for my own application I do know the list up front). Also, I really like the diagrams you’ve generated – very smart! Anyway…. thanks a lot! 🙂

Great stuff, this code/write-up was super helpful.

Excellent work, concise and clear. Very easy to understand