Model Estimators

centimators.model_estimators.keras_estimators

Keras-based model estimators with a scikit-learn-compatible API.

Organized by architectural family:
  • base: BaseKerasEstimator and shared utilities
  • dense: Simple feedforward networks (MLPRegressor)
  • autoencoder: Reconstruction-based architectures (BottleneckEncoder)
  • sequence: Sequence models for temporal data (SequenceEstimator, LSTMRegressor)
  • tree: Differentiable tree ensembles (NeuralDecisionForestRegressor, TemperatureAnnealing)

BaseKerasEstimator dataclass

Bases: TransformerMixin, BaseEstimator, ABC

Meta-estimator for Keras models following the scikit-learn API.

Parameters

output_units : int, default=1
    Number of output units in the final layer.
optimizer : Type[optimizers.Optimizer], default=Adam
    Keras optimizer class to use for training.
learning_rate : float, default=0.001
    Learning rate for the optimizer.
loss_function : str, default="mse"
    Loss function name passed to model.compile().
metrics : list[str] | None, default=None
    List of metric names to track during training.
model : Any, default=None
    The underlying Keras model (populated by build_model).
distribution_strategy : str | None, default=None
    If set, enables DataParallel distribution for multi-device training.
target_scaler : sklearn transformer | None, default=None
    Scaler for target values. Neural networks converge better when targets are normalized. Subclasses may override the default (e.g., regressors default to StandardScaler).

Source code in src/centimators/model_estimators/keras_estimators/base.py
@dataclass(kw_only=True)
class BaseKerasEstimator(TransformerMixin, BaseEstimator, ABC):
    """Meta-estimator for Keras models following the scikit-learn API.

    Parameters
    ----------
    output_units : int, default=1
        Number of output units in the final layer.
    optimizer : Type[optimizers.Optimizer], default=Adam
        Keras optimizer class to use for training.
    learning_rate : float, default=0.001
        Learning rate for the optimizer.
    loss_function : str, default="mse"
        Loss function name passed to model.compile().
    metrics : list[str] | None, default=None
        List of metric names to track during training.
    model : Any, default=None
        The underlying Keras model (populated by build_model).
    distribution_strategy : str | None, default=None
        If set, enables DataParallel distribution for multi-device training.
    target_scaler : sklearn transformer | None, default=None
        Scaler for target values. Neural networks converge better when
        targets are normalized. Subclasses may override the default
        (e.g., regressors default to StandardScaler).
    """

    output_units: int = 1
    optimizer: Type[optimizers.Optimizer] = optimizers.Adam
    learning_rate: float = 0.001
    loss_function: str = "mse"
    metrics: list[str] | None = None
    model: Any = None
    distribution_strategy: str | None = None
    target_scaler: Any = None

    @abstractmethod
    def build_model(self):
        pass

    def _setup_distribution_strategy(self) -> None:
        strategy = distribution.DataParallel()
        distribution.set_distribution(strategy)

    def fit(
        self,
        X,
        y,
        epochs: int = 100,
        batch_size: int = 32,
        validation_data: tuple[Any, Any] | None = None,
        callbacks: list[Any] | None = None,
        **kwargs: Any,
    ) -> "BaseKerasEstimator":
        self._n_features_in_ = X.shape[1]

        if self.distribution_strategy:
            self._setup_distribution_strategy()

        # Convert inputs to numpy
        X_np = _ensure_numpy(X)
        y_np = _ensure_numpy(y, allow_series=True)

        # Ensure y is 2D for scaler
        y_was_1d = y_np.ndim == 1
        if y_was_1d:
            y_np = y_np.reshape(-1, 1)

        # Scale targets for better neural network convergence
        if self.target_scaler:
            y_np = self.target_scaler.fit_transform(y_np).astype("float32")

            # Scale validation targets too
            if validation_data is not None:
                val_X, val_y = validation_data
                val_y_np = _ensure_numpy(val_y, allow_series=True)
                if val_y_np.ndim == 1:
                    val_y_np = val_y_np.reshape(-1, 1)
                val_y_scaled = self.target_scaler.transform(val_y_np).astype("float32")
                validation_data = (_ensure_numpy(val_X), val_y_scaled)

        if not self.model:
            self.build_model()

        self.model.fit(
            X_np,
            y=y_np,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=validation_data,
            callbacks=callbacks,
            **kwargs,
        )
        self._is_fitted = True
        return self

    @nw.narwhalify
    def predict(self, X, batch_size: int = 512, **kwargs: Any) -> Any:
        if not self.model:
            raise ValueError("Model not built. Call `build_model` first.")

        predictions = self.model.predict(
            _ensure_numpy(X), batch_size=batch_size, **kwargs
        )

        # Inverse transform predictions back to original scale
        if self.target_scaler:
            predictions = self.target_scaler.inverse_transform(predictions)

        # Return numpy arrays for numpy input
        if isinstance(X, numpy.ndarray):
            return predictions

        # Return dataframe for dataframe input
        if predictions.ndim == 1:
            return nw.from_dict(
                {"prediction": predictions}, backend=nw.get_native_namespace(X)
            )
        elif predictions.shape[1] == 1:
            return nw.from_dict(
                {"prediction": predictions[:, 0]}, backend=nw.get_native_namespace(X)
            )
        else:
            cols = {
                f"prediction_{i}": predictions[:, i]
                for i in range(predictions.shape[1])
            }
            return nw.from_dict(cols, backend=nw.get_native_namespace(X))

    def transform(self, X, **kwargs):
        return self.predict(X, **kwargs)

    def __sklearn_is_fitted__(self) -> bool:
        return getattr(self, "_is_fitted", False)

MLPRegressor dataclass

Bases: RegressorMixin, BaseKerasEstimator

A minimal fully-connected multi-layer perceptron for tabular data.

Source code in src/centimators/model_estimators/keras_estimators/dense.py
@dataclass(kw_only=True)
class MLPRegressor(RegressorMixin, BaseKerasEstimator):
    """A minimal fully-connected multi-layer perceptron for tabular data."""

    hidden_units: tuple[int, ...] = (64, 64)
    activation: str = "relu"
    dropout_rate: float = 0.0
    metrics: list[str] | None = field(default_factory=lambda: ["mse"])
    target_scaler: Any = field(default_factory=StandardScaler)

    def build_model(self):
        inputs = layers.Input(shape=(self._n_features_in_,), name="features")
        x = inputs
        for units in self.hidden_units:
            x = layers.Dense(units, activation=self.activation)(x)
            if self.dropout_rate > 0:
                x = layers.Dropout(self.dropout_rate)(x)
        outputs = layers.Dense(self.output_units, activation="linear")(x)
        self.model = models.Model(inputs=inputs, outputs=outputs, name="mlp_regressor")

        self.model.compile(
            optimizer=self.optimizer(learning_rate=self.learning_rate),
            loss=self.loss_function,
            metrics=self.metrics,
        )
        return self
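
A minimal usage sketch for MLPRegressor (synthetic data; epochs and layer sizes are arbitrary, and the import path is assumed to mirror the NeuralDecisionForestRegressor example further down):

import numpy as np

from centimators.model_estimators import MLPRegressor  # assumed re-export path

X = np.random.randn(500, 20).astype("float32")
y = np.random.randn(500).astype("float32")

mlp = MLPRegressor(hidden_units=(64, 32), dropout_rate=0.1)
mlp.fit(X, y, epochs=5, batch_size=64, verbose=0)
preds = mlp.predict(X)  # numpy in, numpy out; dataframe input returns a dataframe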

BottleneckEncoder dataclass

Bases: BaseKerasEstimator

A bottleneck autoencoder that can learn latent representations and predict targets.

Source code in src/centimators/model_estimators/keras_estimators/autoencoder.py
@dataclass(kw_only=True)
class BottleneckEncoder(BaseKerasEstimator):
    """A bottleneck autoencoder that can learn latent representations and predict targets."""

    gaussian_noise: float = 0.035
    encoder_units: list[tuple[int, float]] = field(
        default_factory=lambda: [(1024, 0.1)]
    )
    latent_units: tuple[int, float] = (256, 0.1)
    ae_units: list[tuple[int, float]] = field(default_factory=lambda: [(96, 0.4)])
    activation: str = "swish"
    reconstruction_loss_weight: float = 1.0
    target_loss_weight: float = 1.0
    encoder: Any = None

    def build_model(self):
        if self._n_features_in_ is None:
            raise ValueError("Must call fit() before building the model")

        inputs = layers.Input(shape=(self._n_features_in_,), name="features")
        x0 = layers.BatchNormalization()(inputs)

        encoder = layers.GaussianNoise(self.gaussian_noise)(x0)
        for units, dropout in self.encoder_units:
            encoder = layers.Dense(units)(encoder)
            encoder = layers.BatchNormalization()(encoder)
            encoder = layers.Activation(self.activation)(encoder)
            encoder = layers.Dropout(dropout)(encoder)

        latent_units, latent_dropout = self.latent_units
        latent = layers.Dense(latent_units)(encoder)
        latent = layers.BatchNormalization()(latent)
        latent = layers.Activation(self.activation)(latent)
        latent_output = layers.Dropout(latent_dropout)(latent)

        self.encoder = models.Model(
            inputs=inputs, outputs=latent_output, name="encoder"
        )

        decoder = latent_output
        for units, dropout in reversed(self.encoder_units):
            decoder = layers.Dense(units)(decoder)
            decoder = layers.BatchNormalization()(decoder)
            decoder = layers.Activation(self.activation)(decoder)
            decoder = layers.Dropout(dropout)(decoder)

        reconstruction = layers.Dense(self._n_features_in_, name="reconstruction")(
            decoder
        )

        target_pred = reconstruction
        for units, dropout in self.ae_units:
            target_pred = layers.Dense(units)(target_pred)
            target_pred = layers.BatchNormalization()(target_pred)
            target_pred = layers.Activation(self.activation)(target_pred)
            target_pred = layers.Dropout(dropout)(target_pred)

        target_output = layers.Dense(
            self.output_units, activation="linear", name="target_prediction"
        )(target_pred)

        self.model = models.Model(
            inputs=inputs,
            outputs=[reconstruction, target_output],
            name="bottleneck_encoder",
        )

        self.model.compile(
            optimizer=self.optimizer(learning_rate=self.learning_rate),
            loss={"reconstruction": "mse", "target_prediction": self.loss_function},
            loss_weights={
                "reconstruction": self.reconstruction_loss_weight,
                "target_prediction": self.target_loss_weight,
            },
            metrics={"target_prediction": self.metrics or ["mse"]},
        )
        return self

    def fit(
        self,
        X,
        y,
        epochs: int = 100,
        batch_size: int = 32,
        validation_data: tuple[Any, Any] | None = None,
        callbacks: list[Any] | None = None,
        **kwargs: Any,
    ) -> "BottleneckEncoder":
        self._n_features_in_ = X.shape[1]

        if self.distribution_strategy:
            self._setup_distribution_strategy()

        if not self.model:
            self.build_model()

        X_np = _ensure_numpy(X)
        y_np = _ensure_numpy(y, allow_series=True)

        y_dict = {"reconstruction": X_np, "target_prediction": y_np}

        if validation_data is not None:
            X_val, y_val = validation_data
            X_val_np = _ensure_numpy(X_val)
            y_val_np = _ensure_numpy(y_val, allow_series=True)
            validation_data = (
                X_val_np,
                {"reconstruction": X_val_np, "target_prediction": y_val_np},
            )

        self.model.fit(
            X_np,
            y_dict,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=validation_data,
            callbacks=callbacks,
            **kwargs,
        )

        self._is_fitted = True
        return self

    def predict(self, X, batch_size: int = 512, **kwargs: Any) -> Any:
        if not self.model:
            raise ValueError("Model not built. Call 'fit' first.")
        X_np = _ensure_numpy(X)
        predictions = self.model.predict(X_np, batch_size=batch_size, **kwargs)
        return predictions[1] if isinstance(predictions, list) else predictions

    def transform(self, X, batch_size: int = 512, **kwargs: Any) -> Any:
        if not self.encoder:
            raise ValueError("Encoder not built. Call 'fit' first.")
        X_np = _ensure_numpy(X)
        return self.encoder.predict(X_np, batch_size=batch_size, **kwargs)

    def fit_transform(self, X, y, **kwargs) -> Any:
        return self.fit(X, y, **kwargs).transform(X)

    def get_feature_names_out(self, input_features=None) -> list[str]:
        latent_dim = self.latent_units[0]
        return [f"latent_{i}" for i in range(latent_dim)]

SequenceEstimator dataclass

Bases: BaseKerasEstimator

Estimator for models that consume sequential data.

Source code in src/centimators/model_estimators/keras_estimators/sequence.py
@dataclass(kw_only=True)
class SequenceEstimator(BaseKerasEstimator):
    """Estimator for models that consume sequential data."""

    lag_windows: list[int]
    n_features_per_timestep: int

    def __post_init__(self):
        self.seq_length = len(self.lag_windows)

    def _reshape(self, X: IntoFrame, validation_data: tuple[Any, Any] | None = None):
        X = _ensure_numpy(X)
        X_reshaped = ops.reshape(
            X, (X.shape[0], self.seq_length, self.n_features_per_timestep)
        )

        if validation_data:
            X_val, y_val = validation_data
            X_val = _ensure_numpy(X_val)
            X_val_reshaped = ops.reshape(
                X_val,
                (X_val.shape[0], self.seq_length, self.n_features_per_timestep),
            )
            validation_data = X_val_reshaped, _ensure_numpy(y_val)

        return X_reshaped, validation_data

    def fit(
        self, X, y, validation_data: tuple[Any, Any] | None = None, **kwargs: Any
    ) -> "SequenceEstimator":
        X_reshaped, validation_data_reshaped = self._reshape(X, validation_data)
        super().fit(
            X_reshaped,
            y=_ensure_numpy(y),
            validation_data=validation_data_reshaped,
            **kwargs,
        )
        return self

    @nw.narwhalify
    def predict(self, X, batch_size: int = 512, **kwargs: Any) -> Any:
        if not self.model:
            raise ValueError("Model not built. Call `build_model` first.")

        # Store original X for backend detection before reshaping
        X_original = X
        X_reshaped, _ = self._reshape(X)

        predictions = self.model.predict(
            _ensure_numpy(X_reshaped), batch_size=batch_size, **kwargs
        )

        # Inverse transform predictions back to original scale
        if self.target_scaler:
            predictions = self.target_scaler.inverse_transform(predictions)

        # Use X_original (not X_reshaped) for backend detection
        if isinstance(X_original, numpy.ndarray):
            return predictions

        if predictions.ndim == 1:
            return nw.from_dict(
                {"prediction": predictions}, backend=nw.get_native_namespace(X_original)
            )
        else:
            cols = {
                f"prediction_{i}": predictions[:, i]
                for i in range(predictions.shape[1])
            }
            return nw.from_dict(cols, backend=nw.get_native_namespace(X_original))
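
The reshape above implies that X arrives flattened with seq_length * n_features_per_timestep columns, grouped per lag window (all features for the first lag, then the next, and so on). A small shape illustration (numbers chosen arbitrarily):

import numpy as np

lag_windows = [1, 2, 3, 4]       # 4 timesteps
n_features_per_timestep = 5      # 5 features observed at each lag
X_flat = np.random.randn(100, len(lag_windows) * n_features_per_timestep)

# Mirrors what _reshape does internally via keras.ops.reshape
X_seq = X_flat.reshape(100, len(lag_windows), n_features_per_timestep)
print(X_seq.shape)  # (100, 4, 5)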

LSTMRegressor dataclass

Bases: RegressorMixin, SequenceEstimator

LSTM-based regressor for sequence prediction.

Source code in src/centimators/model_estimators/keras_estimators/sequence.py
@dataclass(kw_only=True)
class LSTMRegressor(RegressorMixin, SequenceEstimator):
    """LSTM-based regressor for sequence prediction."""

    lstm_units: list[tuple[int, float, float]] = field(
        default_factory=lambda: [(64, 0.01, 0.01)]
    )
    use_batch_norm: bool = False
    use_layer_norm: bool = False
    bidirectional: bool = False
    metrics: list[str] | None = field(default_factory=lambda: ["mse"])
    target_scaler: Any = field(default_factory=StandardScaler)

    def build_model(self):
        if self._n_features_in_ is None:
            raise ValueError("Must call fit() before building the model")

        inputs = layers.Input(
            shape=(self.seq_length, self.n_features_per_timestep), name="sequence_input"
        )
        x = inputs

        for layer_num, (units, dropout, recurrent_dropout) in enumerate(
            self.lstm_units
        ):
            return_sequences = layer_num < len(self.lstm_units) - 1
            lstm_layer = layers.LSTM(
                units=units,
                activation="tanh",
                return_sequences=return_sequences,
                dropout=dropout,
                recurrent_dropout=recurrent_dropout,
                name=f"lstm_{layer_num}",
            )
            if self.bidirectional:
                x = layers.Bidirectional(lstm_layer, name=f"bidirectional_{layer_num}")(
                    x
                )
            else:
                x = lstm_layer(x)
            if self.use_layer_norm:
                x = layers.LayerNormalization(name=f"layer_norm_{layer_num}")(x)
            if self.use_batch_norm:
                x = layers.BatchNormalization(name=f"batch_norm_{layer_num}")(x)

        outputs = layers.Dense(self.output_units, activation="linear", name="output")(x)
        self.model = models.Model(inputs=inputs, outputs=outputs, name="lstm_regressor")
        self.model.compile(
            optimizer=self.optimizer(learning_rate=self.learning_rate),
            loss=self.loss_function,
            metrics=self.metrics,
        )
        return self
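
A usage sketch for LSTMRegressor on lag-structured features (synthetic data; hyperparameters are arbitrary, import path assumed):

import numpy as np

from centimators.model_estimators import LSTMRegressor  # assumed re-export path

lag_windows = [1, 2, 3, 4, 5]
n_features = 3
X = np.random.randn(200, len(lag_windows) * n_features).astype("float32")
y = np.random.randn(200).astype("float32")

lstm = LSTMRegressor(
    lag_windows=lag_windows,
    n_features_per_timestep=n_features,
    lstm_units=[(32, 0.0, 0.0)],  # (units, dropout, recurrent_dropout)
    bidirectional=True,
)
lstm.fit(X, y, epochs=3, batch_size=32, verbose=0)
preds = lstm.predict(X)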

NeuralDecisionForestRegressor dataclass

Bases: RegressorMixin, BaseKerasEstimator

Neural Decision Forest regressor with differentiable tree ensembles.

A Neural Decision Forest is an ensemble of differentiable decision trees trained end-to-end via gradient descent. Each tree uses stochastic routing where internal nodes learn probability distributions over routing decisions. The forest combines predictions by averaging over all trees.

This architecture provides:
  • Interpretable tree-like structure with learned routing
  • Feature bagging via used_features_rate (like random forests)
  • End-to-end differentiable training
  • Ensemble averaging for improved generalization
  • Temperature-controlled routing sharpness
  • Input noise, per-tree noise, and tree dropout for ensemble diversity

Parameters

num_trees : int, default=25
    Number of decision trees in the forest ensemble.
depth : int, default=4
    Depth of each tree. Each tree will have 2^depth leaf nodes. Deeper trees have more capacity but harder gradient flow.
used_features_rate : float, default=0.5
    Fraction of features each tree randomly selects (0 to 1). Provides feature bagging. Lower values increase diversity.
l2_decision : float, default=1e-4
    L2 regularization for routing decision layers. Lower values allow sharper routing decisions.
l2_leaf : float, default=1e-3
    L2 regularization for leaf output weights. Can be stronger than l2_decision since leaves are regression weights.
temperature : float, default=0.5
    Temperature for sigmoid sharpness in routing. Lower values (0.3-0.5) give sharper, more tree-like routing. Higher values (1-3) give softer routing where samples flow through multiple paths.
input_noise_std : float, default=0.0
    Gaussian noise std applied to inputs before the trunk. Makes the trunk robust to input perturbations. Try 0.02-0.05.
tree_noise_std : float, default=0.0
    Gaussian noise std applied per-tree after the trunk. Each tree sees a different noisy view, decorrelating the ensemble. Try 0.03-0.1.
tree_dropout_rate : float, default=0.0
    Dropout rate for tree outputs during training (0 to 1). Randomly drops tree contributions to decorrelate the ensemble.
trunk_units : list[int] | None, default=None
    Hidden layer sizes for an optional shared MLP trunk before the trees. E.g. [64, 64] adds two Dense+ReLU layers. Trees then split on learned features instead of raw columns.
random_state : int | None, default=None
    Random seed for reproducible feature mask sampling across trees.
output_units : int, default=1
    Number of output targets to predict.
optimizer : Type[keras.optimizers.Optimizer], default=Adam
    Keras optimizer class to use for training.
learning_rate : float, default=0.001
    Learning rate for the optimizer.
loss_function : str, default="mse"
    Loss function for training.
metrics : list[str] | None, default=None
    List of metrics to track during training.
distribution_strategy : str | None, default=None
    Distribution strategy for multi-device training.

Attributes

model : keras.Model
    The compiled Keras model containing the ensemble of trees.
trees : list[NeuralDecisionTree]
    List of tree models in the ensemble.

Examples

>>> from centimators.model_estimators import NeuralDecisionForestRegressor
>>> import numpy as np
>>> X = np.random.randn(100, 10).astype('float32')
>>> y = np.random.randn(100, 1).astype('float32')
>>> ndf = NeuralDecisionForestRegressor(num_trees=5, depth=4)
>>> ndf.fit(X, y, epochs=10, verbose=0)
>>> predictions = ndf.predict(X)

Notes
  • Larger depth increases model capacity but may lead to overfitting
  • More trees generally improve performance but increase computation
  • Lower used_features_rate increases diversity but may hurt individual tree performance
  • Works well on tabular data where tree-based methods traditionally excel
  • Lower temperature (0.3-0.5) gives sharper, more tree-like routing
References

The approach is based on Neural Decision Forests and related differentiable tree architectures that enable end-to-end learning of routing decisions.

Source code in src/centimators/model_estimators/keras_estimators/tree.py
@dataclass(kw_only=True)
class NeuralDecisionForestRegressor(RegressorMixin, BaseKerasEstimator):
    """Neural Decision Forest regressor with differentiable tree ensembles.

    A Neural Decision Forest is an ensemble of differentiable decision trees
    trained end-to-end via gradient descent. Each tree uses stochastic routing
    where internal nodes learn probability distributions over routing decisions.
    The forest combines predictions by averaging over all trees.

    This architecture provides:
    - Interpretable tree-like structure with learned routing
    - Feature bagging via used_features_rate (like random forests)
    - End-to-end differentiable training
    - Ensemble averaging for improved generalization
    - Temperature-controlled routing sharpness
    - Input noise, per-tree noise, and tree dropout for ensemble diversity

    Parameters
    ----------
    num_trees : int, default=25
        Number of decision trees in the forest ensemble.
    depth : int, default=4
        Depth of each tree. Each tree will have 2^depth leaf nodes.
        Deeper trees have more capacity but harder gradient flow.
    used_features_rate : float, default=0.5
        Fraction of features each tree randomly selects (0 to 1).
        Provides feature bagging. Lower values increase diversity.
    l2_decision : float, default=1e-4
        L2 regularization for routing decision layers.
        Lower values allow sharper routing decisions.
    l2_leaf : float, default=1e-3
        L2 regularization for leaf output weights.
        Can be stronger than l2_decision since leaves are regression weights.
    temperature : float, default=0.5
        Temperature for sigmoid sharpness in routing. Lower values (0.3-0.5)
        give sharper, more tree-like routing. Higher values (1-3) give softer
        routing where samples flow through multiple paths.
    input_noise_std : float, default=0.0
        Gaussian noise std applied to inputs before trunk.
        Makes trunk robust to input perturbations. Try 0.02-0.05.
    tree_noise_std : float, default=0.0
        Gaussian noise std applied per-tree after trunk.
        Each tree sees a different noisy view, decorrelating the ensemble.
        Try 0.03-0.1.
    tree_dropout_rate : float, default=0.0
        Dropout rate for tree outputs during training (0 to 1).
        Randomly drops tree contributions to decorrelate ensemble.
    trunk_units : list[int] | None, default=None
        Hidden layer sizes for optional shared MLP trunk before trees.
        E.g. [64, 64] adds two Dense+ReLU layers. Trees then split on
        learned features instead of raw columns.
    random_state : int | None, default=None
        Random seed for reproducible feature mask sampling across trees.
    output_units : int, default=1
        Number of output targets to predict.
    optimizer : Type[keras.optimizers.Optimizer], default=Adam
        Keras optimizer class to use for training.
    learning_rate : float, default=0.001
        Learning rate for the optimizer.
    loss_function : str, default="mse"
        Loss function for training.
    metrics : list[str] | None, default=None
        List of metrics to track during training.
    distribution_strategy : str | None, default=None
        Distribution strategy for multi-device training.

    Attributes
    ----------
    model : keras.Model
        The compiled Keras model containing the ensemble of trees.
    trees : list[NeuralDecisionTree]
        List of tree models in the ensemble.

    Examples
    --------
    >>> from centimators.model_estimators import NeuralDecisionForestRegressor
    >>> import numpy as np
    >>> X = np.random.randn(100, 10).astype('float32')
    >>> y = np.random.randn(100, 1).astype('float32')
    >>> ndf = NeuralDecisionForestRegressor(num_trees=5, depth=4)
    >>> ndf.fit(X, y, epochs=10, verbose=0)
    >>> predictions = ndf.predict(X)

    Notes
    -----
    - Larger depth increases model capacity but may lead to overfitting
    - More trees generally improve performance but increase computation
    - Lower used_features_rate increases diversity but may hurt individual tree performance
    - Works well on tabular data where tree-based methods traditionally excel
    - Lower temperature (0.3-0.5) gives sharper, more tree-like routing

    References
    ----------
    The approach is based on Neural Decision Forests and related differentiable
    tree architectures that enable end-to-end learning of routing decisions.
    """

    num_trees: int = 25
    depth: int = 4
    used_features_rate: float = 0.5
    l2_decision: float = 1e-4
    l2_leaf: float = 1e-3
    temperature: float = 0.5
    input_noise_std: float = 0.0
    tree_noise_std: float = 0.0
    tree_dropout_rate: float = 0.0
    trunk_units: list[int] | None = None
    random_state: int | None = None
    metrics: list[str] | None = field(default_factory=lambda: ["mse"])
    target_scaler: Any = field(default_factory=StandardScaler)

    def __post_init__(self):
        self.trees: list[NeuralDecisionTree] = []

    def build_model(self):
        """Build the neural decision forest model.

        Creates an ensemble of NeuralDecisionTree models with shared input
        and averaged output. Each tree receives normalized input features
        via BatchNormalization. Optionally includes input noise (before trunk
        for robustness), per-tree noise (for diversity), tree dropout, and
        a shared MLP trunk.

        Returns
        -------
        self
            Returns self for method chaining.
        """
        if self.model is None:
            if self.distribution_strategy:
                self._setup_distribution_strategy()

            # Set up RNG for reproducibility
            rng = np.random.default_rng(self.random_state)

            # Input layer
            inputs = layers.Input(shape=(self._n_features_in_,))
            x = layers.BatchNormalization()(inputs)

            # Input noise before trunk (makes trunk robust to perturbations)
            if self.input_noise_std > 0:
                x = layers.GaussianNoise(self.input_noise_std)(x)

            # Optional shared trunk (MLP before trees)
            if self.trunk_units:
                for units in self.trunk_units:
                    x = layers.Dense(units, activation="relu")(x)

            # Determine feature count for trees (trunk output or raw features)
            tree_num_features = (
                self.trunk_units[-1] if self.trunk_units else self._n_features_in_
            )

            # Create ensemble of trees
            self.trees = []
            for _ in range(self.num_trees):
                tree = NeuralDecisionTree(
                    depth=self.depth,
                    num_features=tree_num_features,
                    used_features_rate=self.used_features_rate,
                    output_units=self.output_units,
                    l2_decision=self.l2_decision,
                    l2_leaf=self.l2_leaf,
                    temperature=self.temperature,
                    rng=rng,
                )
                self.trees.append(tree)

            # each tree gets its own noisy view for diversity
            tree_outputs = []
            for tree in self.trees:
                if self.tree_noise_std > 0:
                    noisy_x = layers.GaussianNoise(self.tree_noise_std)(x)
                    tree_outputs.append(tree(noisy_x))
                else:
                    tree_outputs.append(tree(x))

            if len(tree_outputs) > 1:
                stacked = K.stack(tree_outputs, axis=1)  # [batch, num_trees, out_units]
                if self.tree_dropout_rate > 0:
                    # Drop entire trees
                    stacked = layers.Dropout(
                        self.tree_dropout_rate,
                        noise_shape=(
                            None,
                            self.num_trees,
                            1,
                        ),  # broadcasts so whole tree is dropped
                    )(stacked)
                outputs = K.mean(stacked, axis=1)
            else:
                outputs = tree_outputs[0]

            self.model = models.Model(inputs=inputs, outputs=outputs)
            opt = self.optimizer(learning_rate=self.learning_rate)
            self.model.compile(
                optimizer=opt, loss=self.loss_function, metrics=self.metrics
            )
        return self

build_model()

Build the neural decision forest model.

Creates an ensemble of NeuralDecisionTree models with shared input and averaged output. Each tree receives normalized input features via BatchNormalization. Optionally includes input noise (before trunk for robustness), per-tree noise (for diversity), tree dropout, and a shared MLP trunk.

Returns

self
    Returns self for method chaining.

Source code in src/centimators/model_estimators/keras_estimators/tree.py
def build_model(self):
    """Build the neural decision forest model.

    Creates an ensemble of NeuralDecisionTree models with shared input
    and averaged output. Each tree receives normalized input features
    via BatchNormalization. Optionally includes input noise (before trunk
    for robustness), per-tree noise (for diversity), tree dropout, and
    a shared MLP trunk.

    Returns
    -------
    self
        Returns self for method chaining.
    """
    if self.model is None:
        if self.distribution_strategy:
            self._setup_distribution_strategy()

        # Set up RNG for reproducibility
        rng = np.random.default_rng(self.random_state)

        # Input layer
        inputs = layers.Input(shape=(self._n_features_in_,))
        x = layers.BatchNormalization()(inputs)

        # Input noise before trunk (makes trunk robust to perturbations)
        if self.input_noise_std > 0:
            x = layers.GaussianNoise(self.input_noise_std)(x)

        # Optional shared trunk (MLP before trees)
        if self.trunk_units:
            for units in self.trunk_units:
                x = layers.Dense(units, activation="relu")(x)

        # Determine feature count for trees (trunk output or raw features)
        tree_num_features = (
            self.trunk_units[-1] if self.trunk_units else self._n_features_in_
        )

        # Create ensemble of trees
        self.trees = []
        for _ in range(self.num_trees):
            tree = NeuralDecisionTree(
                depth=self.depth,
                num_features=tree_num_features,
                used_features_rate=self.used_features_rate,
                output_units=self.output_units,
                l2_decision=self.l2_decision,
                l2_leaf=self.l2_leaf,
                temperature=self.temperature,
                rng=rng,
            )
            self.trees.append(tree)

        # each tree gets its own noisy view for diversity
        tree_outputs = []
        for tree in self.trees:
            if self.tree_noise_std > 0:
                noisy_x = layers.GaussianNoise(self.tree_noise_std)(x)
                tree_outputs.append(tree(noisy_x))
            else:
                tree_outputs.append(tree(x))

        if len(tree_outputs) > 1:
            stacked = K.stack(tree_outputs, axis=1)  # [batch, num_trees, out_units]
            if self.tree_dropout_rate > 0:
                # Drop entire trees
                stacked = layers.Dropout(
                    self.tree_dropout_rate,
                    noise_shape=(
                        None,
                        self.num_trees,
                        1,
                    ),  # broadcasts so whole tree is dropped
                )(stacked)
            outputs = K.mean(stacked, axis=1)
        else:
            outputs = tree_outputs[0]

        self.model = models.Model(inputs=inputs, outputs=outputs)
        opt = self.optimizer(learning_rate=self.learning_rate)
        self.model.compile(
            optimizer=opt, loss=self.loss_function, metrics=self.metrics
        )
    return self

TemperatureAnnealing

Bases: Callback

Anneal tree routing temperature from soft to sharp over training.

Starts with high temperature (soft routing, samples flow through many paths) and linearly decreases to low temperature (sharp routing, more tree-like). This can theoretically help training converge to better solutions.

Parameters

ndf : NeuralDecisionForestRegressor
    The forest instance whose trees will be annealed.
start : float, default=2.0
    Starting temperature (soft routing).
end : float, default=0.5
    Ending temperature (sharp routing).
epochs : int, default=50
    Total epochs over which to anneal. Should match fit() epochs.

Example

>>> ndf = NeuralDecisionForestRegressor(temperature=2.0)
>>> annealer = TemperatureAnnealing(ndf, start=2.0, end=0.5, epochs=50)
>>> ndf.fit(X, y, epochs=50, callbacks=[annealer])

Source code in src/centimators/model_estimators/keras_estimators/tree.py
class TemperatureAnnealing(callbacks.Callback):
    """Anneal tree routing temperature from soft to sharp over training.

    Starts with high temperature (soft routing, samples flow through many paths)
    and linearly decreases to low temperature (sharp routing, more tree-like).
    This can theoretically help training converge to better solutions.

    Parameters
    ----------
    ndf : NeuralDecisionForestRegressor
        The forest instance whose trees will be annealed.
    start : float, default=2.0
        Starting temperature (soft routing).
    end : float, default=0.5
        Ending temperature (sharp routing).
    epochs : int, default=50
        Total epochs over which to anneal. Should match fit() epochs.

    Example
    -------
    >>> ndf = NeuralDecisionForestRegressor(temperature=2.0)
    >>> annealer = TemperatureAnnealing(ndf, start=2.0, end=0.5, epochs=50)
    >>> ndf.fit(X, y, epochs=50, callbacks=[annealer])
    """

    def __init__(self, ndf, start: float = 2.0, end: float = 0.5, epochs: int = 50):
        super().__init__()
        self.ndf = ndf
        self.start = start
        self.end = end
        self.epochs = epochs

    def on_epoch_end(self, epoch, logs=None):
        t = self.start - (self.start - self.end) * ((epoch + 1) / self.epochs)
        for tree in self.ndf.trees:
            tree.temperature.assign(t)
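
The schedule in on_epoch_end is linear in epochs: t = start - (start - end) * (epoch + 1) / epochs. A quick check of the endpoints, assuming the defaults start=2.0, end=0.5, epochs=50:

start, end, epochs = 2.0, 0.5, 50
for epoch in (0, 24, 49):
    t = start - (start - end) * ((epoch + 1) / epochs)
    print(epoch, round(t, 3))
# 0 1.97
# 24 1.25
# 49 0.5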