Skip to content

glassbox.models.neighbors.index._kdtree

KDTreeIndex — space-partitioning tree for efficient nearest-neighbor queries.


KDTreeIndex

KDTreeIndex(metric, leaf_size=30)

Bases: BaseIndex

Initialize the KDTreeIndex.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metric` | `DistanceMetric` | The distance metric to use. | *required* |
| `leaf_size` | `int` | Number of points at which to switch to brute-force search. | `30` |
Source code in glassbox/models/neighbors/index/_kdtree.py
def __init__(self, metric: DistanceMetric, leaf_size: int = 30) -> None:
    """
    Set up an empty KD-tree index bound to a distance metric.

    No tree is constructed here; ``build`` must be called with training
    data before ``query`` can be used.

    Parameters
    ----------
    metric : DistanceMetric
        The distance metric to use; forwarded to the base index.
    leaf_size : int, default=30
        Number of points at which to switch to brute-force.
    """
    super().__init__(metric)
    # Training data and tree root remain unset until ``build`` runs.
    self.X_train: np.ndarray | None = None
    self.leaf_size: int = leaf_size
    self.root: KDNode | None = None

build

build(X)

Build the KD-Tree structure from the training data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `ndarray` | Training data, shape `(n_samples, n_features)`. | *required* |
Source code in glassbox/models/neighbors/index/_kdtree.py
def build(self, X: np.ndarray) -> None:
    """
    Build the KD-Tree structure from the training data.

    Parameters
    ----------
    X : np.ndarray
        Training data, shape (n_samples, n_features). Converted to a float
        array before the tree is built.

    Raises
    ------
    ValueError
        If ``X`` is not two-dimensional or contains no samples. In that case
        the previously built index (if any) is left untouched.
    """
    X_arr = np.asarray(X, dtype=float)
    if X_arr.ndim != 2:
        raise ValueError(
            "X must be 2-D with shape (n_samples, n_features), "
            f"got ndim={X_arr.ndim}"
        )
    if X_arr.shape[0] == 0:
        raise ValueError("X must contain at least one sample")

    # Invalidate the old tree before swapping in new data. If _build_tree
    # raises, the index then reads as "not built" instead of pairing a stale
    # tree with the new X_train.
    self.root = None
    self.X_train = X_arr
    indices = np.arange(X_arr.shape[0])
    self.root = self._build_tree(indices, depth=0)

query

query(x, k)

Query the KD-Tree for the k nearest neighbors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `ndarray` | Query point(s), shape `(n_features,)` or `(n_queries, n_features)`. | *required* |
| `k` | `int` | Number of nearest neighbors to retrieve. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Indices of the k nearest neighbors. |

Source code in glassbox/models/neighbors/index/_kdtree.py
def query(self, x: np.ndarray, k: int) -> np.ndarray:
    """
    Query the KD-Tree for the k nearest neighbors.

    Parameters
    ----------
    x : np.ndarray
        Query point(s), shape (n_features,) or (n_queries, n_features).
    k : int
        Number of nearest neighbors to retrieve. Must satisfy
        ``1 <= k <= n_samples`` of the training data.

    Returns
    -------
    np.ndarray
        Indices of the k nearest neighbors: shape (k,) for a single 1-D
        query, otherwise (n_queries, k). Within a row, indices appear in
        the order ``_search`` collected them (presumably nearest-first;
        TODO confirm against ``_search``).

    Raises
    ------
    ValueError
        If the index has not been built, if ``x`` is neither 1-D nor 2-D,
        if its feature count differs from the training data, or if ``k``
        is outside ``[1, n_samples]``.
    """
    if self.X_train is None or self.root is None:
        raise ValueError("Index is not built yet")

    x_arr = np.asarray(x, dtype=float)
    single_query = x_arr.ndim == 1
    if single_query:
        x_arr = x_arr.reshape(1, -1)
    elif x_arr.ndim != 2:
        raise ValueError(
            "x must have shape (n_features,) or (n_queries, n_features), "
            f"got ndim={x_arr.ndim}"
        )

    if self.X_train.ndim == 2 and x_arr.shape[1] != self.X_train.shape[1]:
        raise ValueError(
            f"x has {x_arr.shape[1]} features, but the index was built "
            f"with {self.X_train.shape[1]}"
        )

    # A search can yield at most n_train candidates. Without this check a
    # shorter result list is written into a row of length k, which either
    # fails with an opaque broadcast error or, when exactly one neighbor is
    # found, silently repeats that index across the whole row.
    n_train = self.X_train.shape[0]
    if not 1 <= k <= n_train:
        raise ValueError(
            f"k must be between 1 and {n_train} (number of training "
            f"samples), got {k}"
        )

    n_queries = x_arr.shape[0]
    nearest_idx = np.empty((n_queries, k), dtype=int)

    for i, point in enumerate(x_arr):
        # _search fills best_k in place with (distance-like, index) pairs.
        best_k = []
        self._search(self.root, point, k, best_k)
        nearest_idx[i] = [idx for _, idx in best_k]

    if single_query:
        return nearest_idx[0]
    return nearest_idx