
Model Estimators

centimators.model_estimators.keras_estimators

Keras-based model estimators with scikit-learn compatible API.

Organized by architectural family:
  • base: BaseKerasEstimator and shared utilities
  • dense: Simple feedforward networks (MLPRegressor)
  • autoencoder: Reconstruction-based architectures (BottleneckEncoder)
  • sequence: Sequence models for temporal data (SequenceEstimator, LSTMRegressor)
  • tree: Differentiable tree ensembles (NeuralDecisionForestRegressor, TemperatureAnnealing)

BaseKerasEstimator dataclass

Bases: TransformerMixin, BaseEstimator, ABC

Meta-estimator for Keras models following the scikit-learn API.

Parameters:

  • output_units (int, default=1): Number of output units in the final layer.
  • optimizer (Type[optimizers.Optimizer], default=Adam): Keras optimizer class to use for training.
  • learning_rate (float, default=0.001): Learning rate for the optimizer.
  • loss_function (str, default="mse"): Loss function name passed to model.compile().
  • metrics (list[str] | None, default=None): List of metric names to track during training.
  • model (Any, default=None): The underlying Keras model (populated by build_model).
  • distribution_strategy (str | None, default=None): If set, enables DataParallel distribution for multi-device training.
  • target_scaler (sklearn transformer | None, default=None): Scaler for target values. Neural networks converge better when targets are normalized. Subclasses may override the default (e.g., regressors default to StandardScaler).
Source code in src/centimators/model_estimators/keras_estimators/base.py
@dataclass(kw_only=True)
class BaseKerasEstimator(TransformerMixin, BaseEstimator, ABC):
    """Meta-estimator for Keras models following the scikit-learn API.

    Args:
        output_units (int, default=1): Number of output units in the final layer.
        optimizer (Type[optimizers.Optimizer], default=Adam): Keras optimizer class
            to use for training.
        learning_rate (float, default=0.001): Learning rate for the optimizer.
        loss_function (str, default="mse"): Loss function name passed to model.compile().
        metrics (list[str] | None, default=None): List of metric names to track
            during training.
        model (Any, default=None): The underlying Keras model (populated by build_model).
        distribution_strategy (str | None, default=None): If set, enables DataParallel
            distribution for multi-device training.
        target_scaler (sklearn transformer | None, default=None): Scaler for target
            values. Neural networks converge better when targets are normalized.
            Subclasses may override the default (e.g., regressors default to StandardScaler).
    """

    output_units: int = 1
    optimizer: Type[optimizers.Optimizer] = optimizers.Adam
    learning_rate: float = 0.001
    loss_function: str = "mse"
    metrics: list[str] | None = None
    model: Any = None
    distribution_strategy: str | None = None
    target_scaler: Any = None

    @abstractmethod
    def build_model(self):
        pass

    def _setup_distribution_strategy(self) -> None:
        strategy = distribution.DataParallel()
        distribution.set_distribution(strategy)

    def fit(
        self,
        X,
        y,
        epochs: int = 100,
        batch_size: int = 32,
        validation_data: tuple[Any, Any] | None = None,
        callbacks: list[Any] | None = None,
        **kwargs: Any,
    ) -> "BaseKerasEstimator":
        self._n_features_in_ = X.shape[1]

        if self.distribution_strategy:
            self._setup_distribution_strategy()

        # Convert inputs to numpy
        X_np = _ensure_numpy(X)
        y_np = _ensure_numpy(y, allow_series=True)

        # Ensure y is 2D for scaler
        y_was_1d = y_np.ndim == 1
        if y_was_1d:
            y_np = y_np.reshape(-1, 1)

        # Scale targets for better neural network convergence
        if self.target_scaler:
            y_np = self.target_scaler.fit_transform(y_np).astype("float32")

            # Scale validation targets too
            if validation_data is not None:
                val_X, val_y = validation_data
                val_y_np = _ensure_numpy(val_y, allow_series=True)
                if val_y_np.ndim == 1:
                    val_y_np = val_y_np.reshape(-1, 1)
                val_y_scaled = self.target_scaler.transform(val_y_np).astype("float32")
                validation_data = (_ensure_numpy(val_X), val_y_scaled)

        if not self.model:
            self.build_model()

        self.model.fit(
            X_np,
            y=y_np,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=validation_data,
            callbacks=callbacks,
            **kwargs,
        )
        self._is_fitted = True
        return self

    @nw.narwhalify
    def predict(self, X, batch_size: int = 512, **kwargs: Any) -> Any:
        if not self.model:
            raise ValueError("Model not built. Call `build_model` first.")

        predictions = self.model.predict(
            _ensure_numpy(X), batch_size=batch_size, **kwargs
        )

        # Inverse transform predictions back to original scale
        if self.target_scaler:
            predictions = self.target_scaler.inverse_transform(predictions)

        # Return numpy arrays for numpy input
        if isinstance(X, numpy.ndarray):
            return predictions

        # Return dataframe for dataframe input
        if predictions.ndim == 1:
            return nw.from_dict(
                {"prediction": predictions}, backend=nw.get_native_namespace(X)
            )
        elif predictions.shape[1] == 1:
            return nw.from_dict(
                {"prediction": predictions[:, 0]}, backend=nw.get_native_namespace(X)
            )
        else:
            cols = {
                f"prediction_{i}": predictions[:, i]
                for i in range(predictions.shape[1])
            }
            return nw.from_dict(cols, backend=nw.get_native_namespace(X))

    def transform(self, X, **kwargs):
        return self.predict(X, **kwargs)

    def __sklearn_is_fitted__(self) -> bool:
        return getattr(self, "_is_fitted", False)
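
Subclasses only need to implement build_model; fit, predict, transform, and target scaling are inherited. A minimal sketch of a custom subclass (the LinearRegressor name is hypothetical, and the import paths are assumptions based on the source locations shown above):

# Hypothetical subclass sketch, not part of the library.
from dataclasses import dataclass

from keras import layers, models

from centimators.model_estimators.keras_estimators.base import BaseKerasEstimator


@dataclass(kw_only=True)
class LinearRegressor(BaseKerasEstimator):
    """Single Dense layer on top of the inherited fit/predict machinery."""

    def build_model(self):
        # fit() sets self._n_features_in_ before calling build_model().
        inputs = layers.Input(shape=(self._n_features_in_,), name="features")
        outputs = layers.Dense(self.output_units, activation="linear")(inputs)
        self.model = models.Model(inputs=inputs, outputs=outputs)
        self.model.compile(
            optimizer=self.optimizer(learning_rate=self.learning_rate),
            loss=self.loss_function,
            metrics=self.metrics,
        )
        return self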

MLPRegressor dataclass

Bases: RegressorMixin, BaseKerasEstimator

A minimal fully-connected multi-layer perceptron for tabular data.

Source code in src/centimators/model_estimators/keras_estimators/dense.py
@dataclass(kw_only=True)
class MLPRegressor(RegressorMixin, BaseKerasEstimator):
    """A minimal fully-connected multi-layer perceptron for tabular data."""

    hidden_units: tuple[int, ...] = (64, 64)
    activation: str = "relu"
    dropout_rate: float = 0.0
    metrics: list[str] | None = field(default_factory=lambda: ["mse"])
    target_scaler: Any = field(default_factory=StandardScaler)

    def build_model(self):
        inputs = layers.Input(shape=(self._n_features_in_,), name="features")
        x = inputs
        for units in self.hidden_units:
            x = layers.Dense(units, activation=self.activation)(x)
            if self.dropout_rate > 0:
                x = layers.Dropout(self.dropout_rate)(x)
        outputs = layers.Dense(self.output_units, activation="linear")(x)
        self.model = models.Model(inputs=inputs, outputs=outputs, name="mlp_regressor")

        self.model.compile(
            optimizer=self.optimizer(learning_rate=self.learning_rate),
            loss=self.loss_function,
            metrics=self.metrics,
        )
        return self
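
A usage sketch with synthetic data (shapes and hyperparameters are illustrative; the import assumes MLPRegressor is re-exported from centimators.model_estimators, as NeuralDecisionForestRegressor is in the example further below):

>>> from centimators.model_estimators import MLPRegressor
>>> import numpy as np
>>> X = np.random.randn(500, 20).astype("float32")
>>> y = np.random.randn(500).astype("float32")
>>> mlp = MLPRegressor(hidden_units=(128, 64), dropout_rate=0.1)
>>> mlp.fit(X, y, epochs=5, batch_size=64, verbose=0)
>>> preds = mlp.predict(X)

Targets are scaled with the default StandardScaler during fit and inverse-transformed in predict, so preds is on the original scale; numpy input yields numpy output, while dataframe input yields a dataframe of predictions.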

BottleneckEncoder dataclass

Bases: BaseKerasEstimator

A bottleneck autoencoder that can learn latent representations and predict targets.

Source code in src/centimators/model_estimators/keras_estimators/autoencoder.py
@dataclass(kw_only=True)
class BottleneckEncoder(BaseKerasEstimator):
    """A bottleneck autoencoder that can learn latent representations and predict targets."""

    gaussian_noise: float = 0.035
    encoder_units: list[tuple[int, float]] = field(
        default_factory=lambda: [(1024, 0.1)]
    )
    latent_units: tuple[int, float] = (256, 0.1)
    ae_units: list[tuple[int, float]] = field(default_factory=lambda: [(96, 0.4)])
    activation: str = "swish"
    reconstruction_loss_weight: float = 1.0
    target_loss_weight: float = 1.0
    encoder: Any = None

    def build_model(self):
        if self._n_features_in_ is None:
            raise ValueError("Must call fit() before building the model")

        inputs = layers.Input(shape=(self._n_features_in_,), name="features")
        x0 = layers.BatchNormalization()(inputs)

        encoder = layers.GaussianNoise(self.gaussian_noise)(x0)
        for units, dropout in self.encoder_units:
            encoder = layers.Dense(units)(encoder)
            encoder = layers.BatchNormalization()(encoder)
            encoder = layers.Activation(self.activation)(encoder)
            encoder = layers.Dropout(dropout)(encoder)

        latent_units, latent_dropout = self.latent_units
        latent = layers.Dense(latent_units)(encoder)
        latent = layers.BatchNormalization()(latent)
        latent = layers.Activation(self.activation)(latent)
        latent_output = layers.Dropout(latent_dropout)(latent)

        self.encoder = models.Model(
            inputs=inputs, outputs=latent_output, name="encoder"
        )

        decoder = latent_output
        for units, dropout in reversed(self.encoder_units):
            decoder = layers.Dense(units)(decoder)
            decoder = layers.BatchNormalization()(decoder)
            decoder = layers.Activation(self.activation)(decoder)
            decoder = layers.Dropout(dropout)(decoder)

        reconstruction = layers.Dense(self._n_features_in_, name="reconstruction")(
            decoder
        )

        target_pred = reconstruction
        for units, dropout in self.ae_units:
            target_pred = layers.Dense(units)(target_pred)
            target_pred = layers.BatchNormalization()(target_pred)
            target_pred = layers.Activation(self.activation)(target_pred)
            target_pred = layers.Dropout(dropout)(target_pred)

        target_output = layers.Dense(
            self.output_units, activation="linear", name="target_prediction"
        )(target_pred)

        self.model = models.Model(
            inputs=inputs,
            outputs=[reconstruction, target_output],
            name="bottleneck_encoder",
        )

        self.model.compile(
            optimizer=self.optimizer(learning_rate=self.learning_rate),
            loss={"reconstruction": "mse", "target_prediction": self.loss_function},
            loss_weights={
                "reconstruction": self.reconstruction_loss_weight,
                "target_prediction": self.target_loss_weight,
            },
            metrics={"target_prediction": self.metrics or ["mse"]},
        )
        return self

    def fit(
        self,
        X,
        y,
        epochs: int = 100,
        batch_size: int = 32,
        validation_data: tuple[Any, Any] | None = None,
        callbacks: list[Any] | None = None,
        **kwargs: Any,
    ) -> "BottleneckEncoder":
        self._n_features_in_ = X.shape[1]

        if self.distribution_strategy:
            self._setup_distribution_strategy()

        if not self.model:
            self.build_model()

        X_np = _ensure_numpy(X)
        y_np = _ensure_numpy(y, allow_series=True)

        y_dict = {"reconstruction": X_np, "target_prediction": y_np}

        if validation_data is not None:
            X_val, y_val = validation_data
            X_val_np = _ensure_numpy(X_val)
            y_val_np = _ensure_numpy(y_val, allow_series=True)
            validation_data = (
                X_val_np,
                {"reconstruction": X_val_np, "target_prediction": y_val_np},
            )

        self.model.fit(
            X_np,
            y_dict,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=validation_data,
            callbacks=callbacks,
            **kwargs,
        )

        self._is_fitted = True
        return self

    def predict(self, X, batch_size: int = 512, **kwargs: Any) -> Any:
        if not self.model:
            raise ValueError("Model not built. Call 'fit' first.")
        X_np = _ensure_numpy(X)
        predictions = self.model.predict(X_np, batch_size=batch_size, **kwargs)
        return predictions[1] if isinstance(predictions, list) else predictions

    def transform(self, X, batch_size: int = 512, **kwargs: Any) -> Any:
        if not self.encoder:
            raise ValueError("Encoder not built. Call 'fit' first.")
        X_np = _ensure_numpy(X)
        return self.encoder.predict(X_np, batch_size=batch_size, **kwargs)

    def fit_transform(self, X, y, **kwargs) -> Any:
        return self.fit(X, y, **kwargs).transform(X)

    def get_feature_names_out(self, input_features=None) -> list[str]:
        latent_dim = self.latent_units[0]
        return [f"latent_{i}" for i in range(latent_dim)]
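
BottleneckEncoder doubles as a regressor (predict returns the target head) and a transformer (transform returns the latent encoding). A hedged usage sketch, assuming the class is importable from centimators.model_estimators like the other estimators:

>>> from centimators.model_estimators import BottleneckEncoder
>>> import numpy as np
>>> X = np.random.randn(256, 32).astype("float32")
>>> y = np.random.randn(256, 1).astype("float32")
>>> ae = BottleneckEncoder(encoder_units=[(128, 0.1)], latent_units=(16, 0.1))
>>> ae.fit(X, y, epochs=5, verbose=0)
>>> latents = ae.transform(X)   # latent encoding, shape (256, 16)
>>> preds = ae.predict(X)       # target predictions from the supervised head
>>> ae.get_feature_names_out()[:2]
['latent_0', 'latent_1']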

SequenceEstimator dataclass

Bases: BaseKerasEstimator

Estimator for models that consume sequential data.

Source code in src/centimators/model_estimators/keras_estimators/sequence.py
@dataclass(kw_only=True)
class SequenceEstimator(BaseKerasEstimator):
    """Estimator for models that consume sequential data."""

    lag_windows: list[int]
    n_features_per_timestep: int

    def __post_init__(self):
        self.seq_length = len(self.lag_windows)

    def _reshape(self, X: IntoFrame, validation_data: tuple[Any, Any] | None = None):
        X = _ensure_numpy(X)
        X_reshaped = ops.reshape(
            X, (X.shape[0], self.seq_length, self.n_features_per_timestep)
        )

        if validation_data:
            X_val, y_val = validation_data
            X_val = _ensure_numpy(X_val)
            X_val_reshaped = ops.reshape(
                X_val,
                (X_val.shape[0], self.seq_length, self.n_features_per_timestep),
            )
            validation_data = X_val_reshaped, _ensure_numpy(y_val)

        return X_reshaped, validation_data

    def fit(
        self, X, y, validation_data: tuple[Any, Any] | None = None, **kwargs: Any
    ) -> "SequenceEstimator":
        X_reshaped, validation_data_reshaped = self._reshape(X, validation_data)
        super().fit(
            X_reshaped,
            y=_ensure_numpy(y),
            validation_data=validation_data_reshaped,
            **kwargs,
        )
        return self

    @nw.narwhalify
    def predict(self, X, batch_size: int = 512, **kwargs: Any) -> Any:
        if not self.model:
            raise ValueError("Model not built. Call `build_model` first.")

        # Store original X for backend detection before reshaping
        X_original = X
        X_reshaped, _ = self._reshape(X)

        predictions = self.model.predict(
            _ensure_numpy(X_reshaped), batch_size=batch_size, **kwargs
        )

        # Inverse transform predictions back to original scale
        if self.target_scaler:
            predictions = self.target_scaler.inverse_transform(predictions)

        # Use X_original (not X_reshaped) for backend detection
        if isinstance(X_original, numpy.ndarray):
            return predictions

        if predictions.ndim == 1:
            return nw.from_dict(
                {"prediction": predictions}, backend=nw.get_native_namespace(X_original)
            )
        else:
            cols = {
                f"prediction_{i}": predictions[:, i]
                for i in range(predictions.shape[1])
            }
            return nw.from_dict(cols, backend=nw.get_native_namespace(X_original))
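
The estimator expects lagged features already flattened column-wise: with L = len(lag_windows) timesteps and F = n_features_per_timestep raw features, each row of X carries L * F values, which _reshape folds into a (batch, L, F) tensor before training. A small worked sketch of that shape contract (the column ordering is an assumption and must match however the lag features were generated upstream):

>>> import numpy as np
>>> lag_windows = [1, 2, 3]          # L = 3 timesteps per sample
>>> n_features_per_timestep = 4      # F = 4 raw features
>>> X_flat = np.random.randn(10, len(lag_windows) * n_features_per_timestep)
>>> X_flat.reshape(10, len(lag_windows), n_features_per_timestep).shape
(10, 3, 4)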

LSTMRegressor dataclass

Bases: RegressorMixin, SequenceEstimator

LSTM-based regressor for sequence prediction.

Source code in src/centimators/model_estimators/keras_estimators/sequence.py
@dataclass(kw_only=True)
class LSTMRegressor(RegressorMixin, SequenceEstimator):
    """LSTM-based regressor for sequence prediction."""

    lstm_units: list[tuple[int, float, float]] = field(
        default_factory=lambda: [(64, 0.01, 0.01)]
    )
    use_batch_norm: bool = False
    use_layer_norm: bool = False
    bidirectional: bool = False
    metrics: list[str] | None = field(default_factory=lambda: ["mse"])
    target_scaler: Any = field(default_factory=StandardScaler)

    def build_model(self):
        if self._n_features_in_ is None:
            raise ValueError("Must call fit() before building the model")

        inputs = layers.Input(
            shape=(self.seq_length, self.n_features_per_timestep), name="sequence_input"
        )
        x = inputs

        for layer_num, (units, dropout, recurrent_dropout) in enumerate(
            self.lstm_units
        ):
            return_sequences = layer_num < len(self.lstm_units) - 1
            lstm_layer = layers.LSTM(
                units=units,
                activation="tanh",
                return_sequences=return_sequences,
                dropout=dropout,
                recurrent_dropout=recurrent_dropout,
                name=f"lstm_{layer_num}",
            )
            if self.bidirectional:
                x = layers.Bidirectional(lstm_layer, name=f"bidirectional_{layer_num}")(
                    x
                )
            else:
                x = lstm_layer(x)
            if self.use_layer_norm:
                x = layers.LayerNormalization(name=f"layer_norm_{layer_num}")(x)
            if self.use_batch_norm:
                x = layers.BatchNormalization(name=f"batch_norm_{layer_num}")(x)

        outputs = layers.Dense(self.output_units, activation="linear", name="output")(x)
        self.model = models.Model(inputs=inputs, outputs=outputs, name="lstm_regressor")
        self.model.compile(
            optimizer=self.optimizer(learning_rate=self.learning_rate),
            loss=self.loss_function,
            metrics=self.metrics,
        )
        return self
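
A usage sketch on flattened lag features (values are illustrative; the import assumes LSTMRegressor is re-exported at the package level):

>>> from centimators.model_estimators import LSTMRegressor
>>> import numpy as np
>>> X = np.random.randn(200, 3 * 4).astype("float32")   # 3 lags x 4 features, flattened
>>> y = np.random.randn(200).astype("float32")
>>> lstm = LSTMRegressor(
...     lag_windows=[1, 2, 3],
...     n_features_per_timestep=4,
...     lstm_units=[(32, 0.0, 0.0)],
...     bidirectional=True,
... )
>>> lstm.fit(X, y, epochs=5, batch_size=32, verbose=0)
>>> preds = lstm.predict(X)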

NeuralDecisionForestRegressor dataclass

Bases: RegressorMixin, BaseKerasEstimator

Neural Decision Forest regressor with differentiable tree ensembles.

A Neural Decision Forest is an ensemble of differentiable decision trees trained end-to-end via gradient descent. Each tree uses stochastic routing where internal nodes learn probability distributions over routing decisions. The forest combines predictions by averaging over all trees.

This architecture provides:

  • Interpretable tree-like structure with learned routing
  • Feature bagging via used_features_rate (like random forests)
  • End-to-end differentiable training
  • Ensemble averaging for improved generalization
  • Temperature-controlled routing sharpness
  • Input noise, per-tree noise, and tree dropout for ensemble diversity

Parameters:

  • num_trees (int, default=25): Number of decision trees in the forest ensemble.
  • depth (int, default=4): Depth of each tree. Each tree will have 2^depth leaf nodes. Deeper trees have more capacity but harder gradient flow.
  • used_features_rate (float, default=0.5): Fraction of features each tree randomly selects (0 to 1). Provides feature bagging. Lower values increase diversity.
  • l2_decision (float, default=1e-4): L2 regularization for routing decision layers. Lower values allow sharper routing decisions.
  • l2_leaf (float, default=1e-3): L2 regularization for leaf output weights. Can be stronger than l2_decision since leaves are regression weights.
  • temperature (float, default=0.5): Temperature for sigmoid sharpness in routing. Lower values (0.3-0.5) give sharper, more tree-like routing. Higher values (1-3) give softer routing where samples flow through multiple paths.
  • input_noise_std (float, default=0.0): Gaussian noise std applied to inputs before the trunk. Makes the trunk robust to input perturbations. Try 0.02-0.05.
  • tree_noise_std (float, default=0.0): Gaussian noise std applied per-tree after the trunk. Each tree sees a different noisy view, decorrelating the ensemble. Try 0.03-0.1.
  • tree_dropout_rate (float, default=0.0): Dropout rate for tree outputs during training (0 to 1). Randomly drops tree contributions to decorrelate the ensemble.
  • trunk_units (list[int] | None, default=None): Hidden layer sizes for an optional shared MLP trunk before the trees. E.g. [64, 64] adds two Dense+ReLU layers. Trees then split on learned features instead of raw columns.
  • random_state (int | None, default=None): Random seed for reproducible feature mask sampling across trees.
  • output_units (int, default=1): Number of output targets to predict.
  • optimizer (Type[keras.optimizers.Optimizer], default=Adam): Keras optimizer class to use for training.
  • learning_rate (float, default=0.001): Learning rate for the optimizer.
  • loss_function (str, default="mse"): Loss function for training.
  • metrics (list[str] | None, default=["mse"]): List of metrics to track during training.
  • target_scaler (sklearn transformer | None, default=StandardScaler): Scaler for target values, inherited from BaseKerasEstimator.
  • distribution_strategy (str | None, default=None): Distribution strategy for multi-device training.

Attributes:

  • model (keras.Model): The compiled Keras model containing the ensemble of trees.
  • trees (list[NeuralDecisionTree]): List of tree models in the ensemble.

Examples:

>>> from centimators.model_estimators import NeuralDecisionForestRegressor
>>> import numpy as np
>>> X = np.random.randn(100, 10).astype('float32')
>>> y = np.random.randn(100, 1).astype('float32')
>>> ndf = NeuralDecisionForestRegressor(num_trees=5, depth=4)
>>> ndf.fit(X, y, epochs=10, verbose=0)
>>> predictions = ndf.predict(X)

Note
  • Larger depth increases model capacity but may lead to overfitting
  • More trees generally improve performance but increase computation
  • Lower used_features_rate increases diversity but may hurt individual tree performance
  • Works well on tabular data where tree-based methods traditionally excel
  • Lower temperature (0.3-0.5) gives sharper, more tree-like routing

The approach is based on Neural Decision Forests and related differentiable tree architectures that enable end-to-end learning of routing decisions.
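
A hedged configuration sketch combining the ensemble-diversity knobs described above, reusing X and y from the example; the specific values are illustrative starting points drawn from the ranges suggested in the parameter docs, not tuned recommendations:

>>> ndf = NeuralDecisionForestRegressor(
...     num_trees=25,
...     depth=4,
...     used_features_rate=0.5,
...     trunk_units=[64, 64],       # shared MLP trunk; trees split on learned features
...     input_noise_std=0.03,       # input robustness (suggested 0.02-0.05)
...     tree_noise_std=0.05,        # per-tree noise to decorrelate the ensemble
...     tree_dropout_rate=0.1,      # randomly drop whole trees during training
...     random_state=42,
... )
>>> ndf.fit(X, y, epochs=50, verbose=0)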

Source code in src/centimators/model_estimators/keras_estimators/tree.py
@dataclass(kw_only=True)
class NeuralDecisionForestRegressor(RegressorMixin, BaseKerasEstimator):
    """Neural Decision Forest regressor with differentiable tree ensembles.

    A Neural Decision Forest is an ensemble of differentiable decision trees
    trained end-to-end via gradient descent. Each tree uses stochastic routing
    where internal nodes learn probability distributions over routing decisions.
    The forest combines predictions by averaging over all trees.

    This architecture provides:

    - Interpretable tree-like structure with learned routing
    - Feature bagging via used_features_rate (like random forests)
    - End-to-end differentiable training
    - Ensemble averaging for improved generalization
    - Temperature-controlled routing sharpness
    - Input noise, per-tree noise, and tree dropout for ensemble diversity

    Args:
        num_trees (int, default=25): Number of decision trees in the forest ensemble.
        depth (int, default=4): Depth of each tree. Each tree will have 2^depth leaf nodes.
            Deeper trees have more capacity but harder gradient flow.
        used_features_rate (float, default=0.5): Fraction of features each tree randomly
            selects (0 to 1). Provides feature bagging. Lower values increase diversity.
        l2_decision (float, default=1e-4): L2 regularization for routing decision layers.
            Lower values allow sharper routing decisions.
        l2_leaf (float, default=1e-3): L2 regularization for leaf output weights.
            Can be stronger than l2_decision since leaves are regression weights.
        temperature (float, default=0.5): Temperature for sigmoid sharpness in routing.
            Lower values (0.3-0.5) give sharper, more tree-like routing. Higher values
            (1-3) give softer routing where samples flow through multiple paths.
        input_noise_std (float, default=0.0): Gaussian noise std applied to inputs
            before trunk. Makes trunk robust to input perturbations. Try 0.02-0.05.
        tree_noise_std (float, default=0.0): Gaussian noise std applied per-tree after
            trunk. Each tree sees a different noisy view, decorrelating the ensemble.
            Try 0.03-0.1.
        tree_dropout_rate (float, default=0.0): Dropout rate for tree outputs during
            training (0 to 1). Randomly drops tree contributions to decorrelate ensemble.
        trunk_units (list[int] | None, default=None): Hidden layer sizes for optional
            shared MLP trunk before trees. E.g. [64, 64] adds two Dense+ReLU layers.
            Trees then split on learned features instead of raw columns.
        random_state (int | None, default=None): Random seed for reproducible feature
            mask sampling across trees.
        output_units (int, default=1): Number of output targets to predict.
        optimizer (Type[keras.optimizers.Optimizer], default=Adam): Keras optimizer
            class to use for training.
        learning_rate (float, default=0.001): Learning rate for the optimizer.
        loss_function (str, default="mse"): Loss function for training.
        metrics (list[str] | None, default=None): List of metrics to track during training.
        distribution_strategy (str | None, default=None): Distribution strategy for
            multi-device training.

    Attributes:
        model (keras.Model): The compiled Keras model containing the ensemble of trees.
        trees (list[NeuralDecisionTree]): List of tree models in the ensemble.

    Examples:
        >>> from centimators.model_estimators import NeuralDecisionForestRegressor
        >>> import numpy as np
        >>> X = np.random.randn(100, 10).astype('float32')
        >>> y = np.random.randn(100, 1).astype('float32')
        >>> ndf = NeuralDecisionForestRegressor(num_trees=5, depth=4)
        >>> ndf.fit(X, y, epochs=10, verbose=0)
        >>> predictions = ndf.predict(X)

    Note:
        - Larger depth increases model capacity but may lead to overfitting
        - More trees generally improve performance but increase computation
        - Lower used_features_rate increases diversity but may hurt individual tree performance
        - Works well on tabular data where tree-based methods traditionally excel
        - Lower temperature (0.3-0.5) gives sharper, more tree-like routing

        The approach is based on Neural Decision Forests and related differentiable
        tree architectures that enable end-to-end learning of routing decisions.
    """

    num_trees: int = 25
    depth: int = 4
    used_features_rate: float = 0.5
    l2_decision: float = 1e-4
    l2_leaf: float = 1e-3
    temperature: float = 0.5
    input_noise_std: float = 0.0
    tree_noise_std: float = 0.0
    tree_dropout_rate: float = 0.0
    trunk_units: list[int] | None = None
    random_state: int | None = None
    metrics: list[str] | None = field(default_factory=lambda: ["mse"])
    target_scaler: Any = field(default_factory=StandardScaler)

    def __post_init__(self):
        self.trees: list[NeuralDecisionTree] = []

    def build_model(self):
        """Build the neural decision forest model.

        Creates an ensemble of NeuralDecisionTree models with shared input
        and averaged output. Each tree receives normalized input features
        via BatchNormalization. Optionally includes input noise (before trunk
        for robustness), per-tree noise (for diversity), tree dropout, and
        a shared MLP trunk.

        Returns:
            self: Returns self for method chaining.
        """
        if self.model is None:
            if self.distribution_strategy:
                self._setup_distribution_strategy()

            # Set up RNG for reproducibility
            rng = np.random.default_rng(self.random_state)

            # Input layer
            inputs = layers.Input(shape=(self._n_features_in_,))
            x = layers.BatchNormalization()(inputs)

            # Input noise before trunk (makes trunk robust to perturbations)
            if self.input_noise_std > 0:
                x = layers.GaussianNoise(self.input_noise_std)(x)

            # Optional shared trunk (MLP before trees)
            if self.trunk_units:
                for units in self.trunk_units:
                    x = layers.Dense(units, activation="relu")(x)

            # Determine feature count for trees (trunk output or raw features)
            tree_num_features = (
                self.trunk_units[-1] if self.trunk_units else self._n_features_in_
            )

            # Create ensemble of trees
            self.trees = []
            for _ in range(self.num_trees):
                tree = NeuralDecisionTree(
                    depth=self.depth,
                    num_features=tree_num_features,
                    used_features_rate=self.used_features_rate,
                    output_units=self.output_units,
                    l2_decision=self.l2_decision,
                    l2_leaf=self.l2_leaf,
                    temperature=self.temperature,
                    rng=rng,
                )
                self.trees.append(tree)

            # each tree gets its own noisy view for diversity
            tree_outputs = []
            for tree in self.trees:
                if self.tree_noise_std > 0:
                    noisy_x = layers.GaussianNoise(self.tree_noise_std)(x)
                    tree_outputs.append(tree(noisy_x))
                else:
                    tree_outputs.append(tree(x))

            if len(tree_outputs) > 1:
                stacked = K.stack(tree_outputs, axis=1)  # [batch, num_trees, out_units]
                if self.tree_dropout_rate > 0:
                    # Drop entire trees
                    stacked = layers.Dropout(
                        self.tree_dropout_rate,
                        noise_shape=(
                            None,
                            self.num_trees,
                            1,
                        ),  # broadcasts so whole tree is dropped
                    )(stacked)
                outputs = K.mean(stacked, axis=1)
            else:
                outputs = tree_outputs[0]

            self.model = models.Model(inputs=inputs, outputs=outputs)
            opt = self.optimizer(learning_rate=self.learning_rate)
            self.model.compile(
                optimizer=opt, loss=self.loss_function, metrics=self.metrics
            )
        return self

build_model()

Build the neural decision forest model.

Creates an ensemble of NeuralDecisionTree models with shared input and averaged output. Each tree receives normalized input features via BatchNormalization. Optionally includes input noise (before trunk for robustness), per-tree noise (for diversity), tree dropout, and a shared MLP trunk.

Returns:

  • self: Returns self for method chaining.

TemperatureAnnealing

Bases: Callback

Anneal tree routing temperature from soft to sharp over training.

Starts with high temperature (soft routing, samples flow through many paths) and linearly decreases to low temperature (sharp routing, more tree-like). This can theoretically help training converge to better solutions.

Parameters:

  • ndf (NeuralDecisionForestRegressor, required): The forest instance whose trees will be annealed.
  • start (float, default=2.0): Starting temperature (soft routing).
  • end (float, default=0.5): Ending temperature (sharp routing).
  • epochs (int, default=50): Total epochs over which to anneal. Should match the fit() epochs.

Examples:

>>> ndf = NeuralDecisionForestRegressor(temperature=2.0)
>>> annealer = TemperatureAnnealing(ndf, start=2.0, end=0.5, epochs=50)
>>> ndf.fit(X, y, epochs=50, callbacks=[annealer])

Source code in src/centimators/model_estimators/keras_estimators/tree.py
class TemperatureAnnealing(callbacks.Callback):
    """Anneal tree routing temperature from soft to sharp over training.

    Starts with high temperature (soft routing, samples flow through many paths)
    and linearly decreases to low temperature (sharp routing, more tree-like).
    This can theoretically help training converge to better solutions.

    Args:
        ndf (NeuralDecisionForestRegressor): The forest instance whose trees will be annealed.
        start (float, default=2.0): Starting temperature (soft routing).
        end (float, default=0.5): Ending temperature (sharp routing).
        epochs (int, default=50): Total epochs over which to anneal. Should match fit() epochs.

    Examples:
        >>> ndf = NeuralDecisionForestRegressor(temperature=2.0)
        >>> annealer = TemperatureAnnealing(ndf, start=2.0, end=0.5, epochs=50)
        >>> ndf.fit(X, y, epochs=50, callbacks=[annealer])
    """

    def __init__(self, ndf, start: float = 2.0, end: float = 0.5, epochs: int = 50):
        super().__init__()
        self.ndf = ndf
        self.start = start
        self.end = end
        self.epochs = epochs

    def on_epoch_end(self, epoch, logs=None):
        t = self.start - (self.start - self.end) * ((epoch + 1) / self.epochs)
        for tree in self.ndf.trees:
            tree.temperature.assign(t)