Skip to content

glassbox.models.neighbors.index

Spatial index implementations for KNN search.


BaseIndex

BaseIndex(metric)

Bases: ABC

Initialize the abstract index class.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metric` | `DistanceMetric` | The distance metric to use. | *required* |
Source code in glassbox/models/neighbors/index/_base.py
def __init__(self, metric: DistanceMetric) -> None:
    """
    Set up the state shared by every index implementation.

    Parameters
    ----------
    metric : DistanceMetric
        Metric that concrete indexes use to measure distances between points.
    """
    self.metric: DistanceMetric = metric

build abstractmethod

build(X)

Build the index structure from the training data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `ndarray` | Training data, shape (n_samples, n_features). | *required* |
Source code in glassbox/models/neighbors/index/_base.py
@abstractmethod
def build(self, X: np.ndarray) -> None:
    """
    Construct the search structure over the training data.

    Concrete subclasses must override this method.

    Parameters
    ----------
    X : np.ndarray
        Training data, shape (n_samples, n_features).

    Raises
    ------
    NotImplementedError
        Always, when this base implementation is invoked directly.
    """
    raise NotImplementedError()

query abstractmethod

query(x, k)

Query the index for the k nearest neighbors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `ndarray` | Query point(s), shape (n_features,) or (n_queries, n_features). | *required* |
| `k` | `int` | Number of nearest neighbors to retrieve. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Indices of the k nearest neighbors. |

Source code in glassbox/models/neighbors/index/_base.py
@abstractmethod
def query(self, x: np.ndarray, k: int) -> np.ndarray:
    """
    Look up the k nearest training points for one or more query points.

    Concrete subclasses must override this method.

    Parameters
    ----------
    x : np.ndarray
        Query point(s), shape (n_features,) or (n_queries, n_features).
    k : int
        Number of nearest neighbors to retrieve.

    Returns
    -------
    np.ndarray
        Indices of the k nearest neighbors.

    Raises
    ------
    NotImplementedError
        Always, when this base implementation is invoked directly.
    """
    raise NotImplementedError()

BruteForceIndex

BruteForceIndex(metric)

Bases: BaseIndex

Initialize the BruteForceIndex.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metric` | `DistanceMetric` | The distance metric to use for querying. | *required* |
Source code in glassbox/models/neighbors/index/_brute.py
def __init__(self, metric: DistanceMetric) -> None:
    """
    Create an empty brute-force index.

    Parameters
    ----------
    metric : DistanceMetric
        Distance metric used to rank training points when querying.
    """
    super().__init__(metric)
    # Filled in by ``build``; while it is ``None`` the index cannot be queried.
    self.X_train: np.ndarray | None = None

build

build(X)

Build the brute-force index by storing the training data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `ndarray` | Training data, shape (n_samples, n_features). | *required* |
Source code in glassbox/models/neighbors/index/_brute.py
def build(self, X: np.ndarray) -> None:
    """
    Build the brute-force index by storing the training data.

    No search structure is precomputed: the data is converted to a float
    array and compared exhaustively at query time.

    Parameters
    ----------
    X : np.ndarray
        Training data, shape (n_samples, n_features).

    Raises
    ------
    ValueError
        If ``X`` is not two-dimensional. The previously stored data, if
        any, is left untouched in that case.
    """
    X_arr = np.asarray(X, dtype=float)
    # Reject malformed input here instead of failing later, with a confusing
    # error, inside ``query``.
    if X_arr.ndim != 2:
        raise ValueError(
            "X must be 2-D with shape (n_samples, n_features), "
            f"got shape {X_arr.shape}"
        )
    self.X_train = X_arr

query

query(x, k)

Query the index for the k nearest neighbors using exhaustive search.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `ndarray` | Query point(s), shape (n_features,) or (n_queries, n_features). | *required* |
| `k` | `int` | Number of nearest neighbors to retrieve. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Indices of the k nearest neighbors. |

Source code in glassbox/models/neighbors/index/_brute.py
def query(self, x: np.ndarray, k: int) -> np.ndarray:
    """
    Query the index for the k nearest neighbors using exhaustive search.

    The distance from each query point to every stored training point is
    computed. The indices of the ``k`` smallest distances are returned in
    ascending order of distance. Ties are broken by training-set order
    (lower index first), so results are deterministic.

    Parameters
    ----------
    x : np.ndarray
        Query point(s), shape (n_features,) or (n_queries, n_features).
    k : int
        Number of nearest neighbors to retrieve. Must satisfy
        ``0 <= k <= n_samples``.

    Returns
    -------
    np.ndarray
        Indices of the k nearest neighbors: shape (k,) for a single query
        point, (n_queries, k) otherwise.

    Raises
    ------
    ValueError
        If the index has not been built, if ``k`` is out of range, if ``x``
        is not 1-D/2-D or its feature count differs from the training data,
        or if the metric is unsupported.
    """
    if self.X_train is None:
        raise ValueError("Index is not built yet")

    n_samples, n_features = self.X_train.shape
    # At most n_samples distinct neighbors exist; a larger k would leave
    # output rows unfillable.
    if not 0 <= k <= n_samples:
        raise ValueError(
            f"k must be in [0, {n_samples}] (number of training samples), "
            f"got {k}"
        )

    x_arr = np.asarray(x, dtype=float)
    input_shape = x_arr.shape
    single_query = x_arr.ndim == 1
    if single_query:
        x_arr = x_arr.reshape(1, -1)
    # Without this check, a feature-count mismatch can silently broadcast
    # (e.g. X_train of shape (n, 1)) and produce meaningless distances.
    if x_arr.ndim != 2 or x_arr.shape[1] != n_features:
        raise ValueError(
            f"x must have shape ({n_features},) or (n_queries, {n_features}), "
            f"got {input_shape}"
        )

    # Resolve the metric once instead of re-checking it for every query row.
    if self.metric == DistanceMetric.EUCLIDEAN:
        use_euclidean = True
    elif self.metric == DistanceMetric.MANHATTAN:
        use_euclidean = False
    else:
        raise ValueError(f"Unsupported metric: {self.metric}")

    nearest_idx = np.empty((x_arr.shape[0], k), dtype=int)
    for i, point in enumerate(x_arr):
        diff = self.X_train - point
        if use_euclidean:
            dist = np.sqrt(np.sum(diff**2, axis=1))
        else:
            dist = np.sum(np.abs(diff), axis=1)
        # A stable sort breaks distance ties by training-set order, so the
        # result does not depend on the sorting algorithm's internals.
        nearest_idx[i] = np.argsort(dist, kind="stable")[:k]

    return nearest_idx[0] if single_query else nearest_idx

KDTreeIndex

KDTreeIndex(metric, leaf_size=30)

Bases: BaseIndex

Initialize the KDTreeIndex.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metric` | `DistanceMetric` | The distance metric to use. | *required* |
| `leaf_size` | `int` | Number of points at which to switch to brute-force. | `30` |
Source code in glassbox/models/neighbors/index/_kdtree.py
def __init__(self, metric: DistanceMetric, leaf_size: int = 30) -> None:
    """
    Initialize the KDTreeIndex.

    Parameters
    ----------
    metric : DistanceMetric
        The distance metric to use.
    leaf_size : int, default=30
        Number of points at which to switch to brute-force. Must be >= 1.

    Raises
    ------
    ValueError
        If ``leaf_size`` is less than 1.
    """
    # A point-count cut-off below 1 is meaningless. Presumably the tree
    # builder also relies on it to stop recursing; confirm against
    # ``_build_tree``.
    if leaf_size < 1:
        raise ValueError(f"leaf_size must be >= 1, got {leaf_size}")
    super().__init__(metric)
    # ``root`` and ``X_train`` stay ``None`` until ``build`` is called;
    # ``query`` refuses to run while either is ``None``.
    self.root: KDNode | None = None
    self.leaf_size: int = leaf_size
    self.X_train: np.ndarray | None = None

build

build(X)

Build the KD-Tree structure from the training data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `ndarray` | Training data, shape (n_samples, n_features). | *required* |
Source code in glassbox/models/neighbors/index/_kdtree.py
def build(self, X: np.ndarray) -> None:
    """
    Build the KD-Tree structure from the training data.

    Parameters
    ----------
    X : np.ndarray
        Training data, shape (n_samples, n_features).

    Raises
    ------
    ValueError
        If ``X`` is not two-dimensional. A previously built tree is left
        untouched in that case.
    """
    X_arr = np.asarray(X, dtype=float)
    # Validate before touching any state so a bad call does not clobber an
    # existing index.
    if X_arr.ndim != 2:
        raise ValueError(
            "X must be 2-D with shape (n_samples, n_features), "
            f"got shape {X_arr.shape}"
        )
    self.X_train = X_arr
    # The tree builder receives row indices into X_train, starting at depth 0.
    self.root = self._build_tree(np.arange(X_arr.shape[0]), depth=0)

query

query(x, k)

Query the KD-Tree for the k nearest neighbors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `ndarray` | Query point(s), shape (n_features,) or (n_queries, n_features). | *required* |
| `k` | `int` | Number of nearest neighbors to retrieve. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Indices of the k nearest neighbors. |

Source code in glassbox/models/neighbors/index/_kdtree.py
def query(self, x: np.ndarray, k: int) -> np.ndarray:
    """
    Query the KD-Tree for the k nearest neighbors.

    Parameters
    ----------
    x : np.ndarray
        Query point(s), shape (n_features,) or (n_queries, n_features).
    k : int
        Number of nearest neighbors to retrieve. Must satisfy
        ``0 <= k <= n_samples``.

    Returns
    -------
    np.ndarray
        Indices of the k nearest neighbors: shape (k,) for a single query
        point, (n_queries, k) otherwise.

    Raises
    ------
    ValueError
        If the index has not been built, if ``k`` is out of range, or if
        ``x`` is not 1-D/2-D or has a different number of features than
        the training data.
    """
    if self.X_train is None or self.root is None:
        raise ValueError("Index is not built yet")

    n_samples, n_features = self.X_train.shape
    # The tree holds only n_samples points; a larger k would leave output
    # rows short and fail with an opaque broadcasting error.
    if not 0 <= k <= n_samples:
        raise ValueError(
            f"k must be in [0, {n_samples}] (number of training samples), "
            f"got {k}"
        )

    x_arr = np.asarray(x, dtype=float)
    input_shape = x_arr.shape
    single_query = x_arr.ndim == 1
    if single_query:
        x_arr = x_arr.reshape(1, -1)
    if x_arr.ndim != 2 or x_arr.shape[1] != n_features:
        raise ValueError(
            f"x must have shape ({n_features},) or (n_queries, {n_features}), "
            f"got {input_shape}"
        )

    nearest_idx = np.empty((x_arr.shape[0], k), dtype=int)

    # With k == 0 there is nothing to find, so skip the tree walk entirely.
    if k > 0:
        for i, point in enumerate(x_arr):
            best_k = []  # filled in place by _search
            self._search(self.root, point, k, best_k)
            # NOTE(review): assumes _search leaves best_k as 2-tuples whose
            # second item is the training index, ordered nearest-first.
            # Confirm against _search.
            nearest_idx[i] = [idx for _, idx in best_k]

    return nearest_idx[0] if single_query else nearest_idx