Skip to content

glassbox.models.trees._base

Abstract BaseTree with recursive tree building and traversal logic.


BaseTree

BaseTree(max_depth=100, min_samples_split=2)

Bases: BaseModel

Abstract base class for all tree-based models.

Initialize the base tree model.

Parameters:

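| Name | Type | Description | Default |
|------|------|-------------|---------|
| `max_depth` | `int` or `None` | Maximum depth of the tree; `None` means no depth limit. | `100` |
| `min_samples_split` | `int` | Minimum number of samples required to split an internal node. | `2` |
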
Source code in glassbox/models/trees/_base.py
def __init__(self, max_depth: Optional[int] = 100, min_samples_split: int = 2) -> None:
    """
    Initialize the base tree model.

    Parameters
    ----------
    max_depth : int or None, default=100
        Maximum depth of the tree. ``None`` removes the depth limit; it is
        stored internally as ``float("inf")``.
    min_samples_split : int, default=2
        Minimum number of samples required to split an internal node.
    """
    # Normalise "no limit" (None) to +inf so the stored limit is always numeric.
    self.max_depth = max_depth if max_depth is not None else float("inf")
    self.min_samples_split = min_samples_split
    # Populated by ``fit``; ``None`` marks the model as not yet fitted.
    self.root: Optional[_Node] = None

fit

fit(X, y)

Fits the tree model to the training data.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `X` | `ndarray` | Training data of shape (n_samples, n_features). | *required* |
| `y` | `ndarray` | Target values of shape (n_samples,). | *required* |


Returns:

| Type | Description |
|------|-------------|
| `Self` | The fitted model. |


Source code in glassbox/models/trees/_base.py
def fit(self, X: np.ndarray, y: np.ndarray) -> Self:
    """
    Fits the tree model to the training data.

    Parameters
    ----------
    X : np.ndarray
        Training data of shape (n_samples, n_features).
    y : np.ndarray
        Target values of shape (n_samples,).

    Returns
    -------
    Self
        The fitted model.
    """
    self.root = self._build_tree(X, y, depth=0)
    return self

predict

predict(X, **kwargs)

Predicts target values for the given data.

Parameters:

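| Name | Type | Description | Default |
|------|------|-------------|---------|
| `X` | `ndarray` | Data to predict on, of shape (n_samples, n_features). | *required* |
| `**kwargs` | `Any` | Additional keyword arguments (accepted but not used by this base implementation). | `{}` |
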

Returns:

| Type | Description |
|------|-------------|
| `ndarray` | Predicted target values. |


Source code in glassbox/models/trees/_base.py
def predict(self, X: np.ndarray, **kwargs: Any) -> np.ndarray:
    """
    Predict a target value for every row of ``X``.

    Each sample is routed from the root node via ``self._traverse_tree``;
    the per-sample results are collected into a single array.

    Parameters
    ----------
    X : np.ndarray
        Data to predict on, of shape (n_samples, n_features).
    **kwargs : Any
        Additional keyword arguments. Accepted for interface compatibility;
        this base implementation does not use them.

    Returns
    -------
    np.ndarray
        Predicted target values.

    Raises
    ------
    RuntimeError
        If no tree has been built yet (``fit`` has not been called).
    """
    root = self.root
    if root is None:
        raise RuntimeError("Model is not fitted yet.")
    traverse = self._traverse_tree
    predictions = [traverse(sample, root) for sample in X]
    return np.array(predictions)