How to use the m2cgen.ast module in m2cgen

To help you get started, we’ve selected a few m2cgen examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: BayesWitnesses / m2cgen / tests / assemblers / test_ensemble.py (View on GitHub)
def test_single_condition():
    estimator = ensemble.RandomForestRegressor(n_estimators=2, random_state=1)

    estimator.fit([[1], [2]], [1, 2])

    assembler = assemblers.RandomForestModelAssembler(estimator)
    actual = assembler.assemble()

    expected = ast.BinNumExpr(
        ast.BinNumExpr(
            ast.SubroutineExpr(
                ast.NumVal(1.0)),
            ast.NumVal(0.5),
            ast.BinNumOpType.MUL),
        ast.BinNumExpr(
            ast.SubroutineExpr(
                ast.IfExpr(
                    ast.CompExpr(
                        ast.FeatureRef(0),
                        ast.NumVal(1.5),
                        ast.CompOpType.LTE),
                    ast.NumVal(1.0),
                    ast.NumVal(2.0))),
            ast.NumVal(0.5),
            ast.BinNumOpType.MUL),
GitHub: BayesWitnesses / m2cgen / tests / assemblers / test_svm.py (View on GitHub)
def kernel_ast(sup_vec_value):
    """Build the expected AST for one kernel term.

    The term is ``sup_vec_value * FeatureRef(0)``, wrapped in a
    ``SubroutineExpr``. The support-vector value is embedded as a
    ``NumVal`` constant.
    """
    # Product of the constant support-vector component and the first
    # (and, presumably, only) input feature of the test model.
    term = ast.BinNumExpr(
        ast.NumVal(sup_vec_value),
        ast.FeatureRef(0),
        ast.BinNumOpType.MUL)
    return ast.SubroutineExpr(term)
GitHub: BayesWitnesses / m2cgen / tests / assemblers / test_tree.py (View on GitHub)
ast.VectorVal([
            ast.NumVal(1.0),
            ast.NumVal(0.0),
            ast.NumVal(0.0)]),
        ast.IfExpr(
            ast.CompExpr(
                ast.FeatureRef(0),
                ast.NumVal(2.5),
                ast.CompOpType.LTE),
            ast.VectorVal([
                ast.NumVal(0.0),
                ast.NumVal(1.0),
                ast.NumVal(0.0)]),
            ast.VectorVal([
                ast.NumVal(0.0),
                ast.NumVal(0.0),
                ast.NumVal(1.0)])))

    assert utils.cmp_exprs(actual, expected)
GitHub: BayesWitnesses / m2cgen / m2cgen / assemblers / tree.py (View on GitHub)
def _assemble_branch(self, node_id):
    """Assemble an ``ast.IfExpr`` for the internal tree node ``node_id``.

    The condition comes from ``self._assemble_cond(node_id)``. When the
    condition holds, the result is the assembled left child; otherwise it
    is the assembled right child.

    Child ids are read from ``self._tree.children_left`` and
    ``self._tree.children_right``. These are presumably scikit-learn tree
    arrays; verify against the caller.
    """
    tree = self._tree
    true_child = tree.children_left[node_id]
    false_child = tree.children_right[node_id]

    # Keep the original order: condition first, then left subtree, then
    # right subtree, in case the assemble helpers have ordering-sensitive
    # side effects.
    condition = self._assemble_cond(node_id)
    if_true = self._assemble_node(true_child)
    if_false = self._assemble_node(false_child)
    return ast.IfExpr(condition, if_true, if_false)
GitHub: BayesWitnesses / m2cgen / m2cgen / interpreters / mixins.py (View on GitHub)
def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs):
    """Track how deeply binary expressions are nested.

    Always returns a ``(result, kwargs)`` pair:

    * ``expr`` is not an ``ast.BinExpr``: ``(None, kwargs)``. The
      ``bin_depth`` value is not carried forward in ``kwargs``.
    * ``expr`` is a ``BinExpr`` and ``bin_depth`` equals
      ``self.bin_depth_threshold``: the result of
      ``self.bin_depth_threshold_hook(expr, **kwargs)``, paired with
      ``kwargs``.
    * Otherwise: ``(None, kwargs)``, with ``kwargs["bin_depth"]`` set to
      ``bin_depth + 1``. Presumably the caller passes these kwargs on when
      interpreting sub-expressions; confirm against the caller.
    """
    if isinstance(expr, ast.BinExpr):
        if bin_depth == self.bin_depth_threshold:
            # The depth limit has been reached, so hand the expression
            # to the threshold hook.
            return self.bin_depth_threshold_hook(expr, **kwargs), kwargs
        # Go one level deeper for nested binary expressions.
        kwargs["bin_depth"] = bin_depth + 1
    return None, kwargs
GitHub: BayesWitnesses / m2cgen / m2cgen / assemblers / svm.py (View on GitHub)
"rbf": self._rbf_kernel,
            "sigmoid": self._sigmoid_kernel,
            "poly": self._poly_kernel,
            "linear": self._linear_kernel
        }
        kernel_type = model.kernel
        if kernel_type not in supported_kernels:
            raise ValueError("Unsupported kernel type {}".format(kernel_type))
        self._kernel_fun = supported_kernels[kernel_type]

        n_features = len(model.support_vectors_[0])

        gamma = model.gamma
        if gamma == "auto" or gamma == "auto_deprecated":
            gamma = 1.0 / n_features
        self._gamma_expr = ast.NumVal(gamma)
        self._neg_gamma_expr = utils.sub(ast.NumVal(0), ast.NumVal(gamma),
                                         to_reuse=True)

        self._output_size = 1
        if type(model).__name__ in ("SVC", "NuSVC"):
            n_classes = len(model.n_support_)
            if n_classes > 2:
                self._output_size = n_classes