How to use the ngboost.distns.ClassificationDistn class in ngboost

To help you get started, we’ve selected a few ngboost examples based on how ClassificationDistn is commonly used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source: stanfordmlgroup/ngboost on GitHub — ngboost/distns/categorical.py (view on GitHub)
def k_categorical(K):
    """
    Factory function that generates classes for K-class categorical distributions for NGBoost

    The generated distribution has two parameters, loc and scale, which are the mean and standard deviation, respectively.
    This distribution has both LogScore and CRPScore implemented for it.
    """

    class Categorical(ClassificationDistn):

        scores = [CategoricalLogScore]
        problem_type = "classification"
        n_params = K - 1
        K_ = K

        def __init__(self, params):
            """Build per-observation class logits and probabilities.

            Args:
                params: 2-D array of fitted parameters with one column per
                    observation. It is expected to have K - 1 rows, matching
                    ``n_params``.

            Side effects:
                ``self.logits`` becomes a (K, N) array whose first row
                (class 0) is fixed at zero and whose remaining rows are
                ``params``. ``self.probs`` becomes the softmax of those
                logits taken down each column, so every column holds one
                observation's class probabilities.
            """
            super().__init__(params)
            # Tuple unpacking also rejects params that are not 2-D.
            _, n_obs = params.shape
            # Class 0's logit is pinned at zero; the fitted rows fill in below it.
            self.logits = np.zeros(shape=(K, n_obs))
            self.logits[1:] = params
            # Normalise each column (one observation) across the K classes.
            self.probs = sp.special.softmax(self.logits, axis=0)

        def fit(Y):
Source: stanfordmlgroup/ngboost on GitHub — ngboost/api.py (view on GitHub)
self,
        Dist=Bernoulli,
        Score=LogScore,
        Base=default_tree_learner,
        natural_gradient=True,
        n_estimators=500,
        learning_rate=0.01,
        minibatch_frac=1.0,
        col_sample=1.0,
        verbose=True,
        verbose_eval=100,
        tol=1e-4,
        random_state=None,
    ):
        assert issubclass(
            Dist, ClassificationDistn
        ), f"{Dist.__name__} is not useable for classification."
        super().__init__(
            Dist,
            Score,
            Base,
            natural_gradient,
            n_estimators,
            learning_rate,
            minibatch_frac,
            col_sample,
            verbose,
            verbose_eval,
            tol,
            random_state,
        )