Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
# Test: a 2-tree RandomForestRegressor fit on two samples is assembled into
# an expression in which each tree's term is multiplied by 0.5 (presumably
# the per-tree averaging weight for n_estimators=2 — confirm).
def test_single_condition():
estimator = ensemble.RandomForestRegressor(n_estimators=2, random_state=1)
estimator.fit([[1], [2]], [1, 2])
assembler = assemblers.RandomForestModelAssembler(estimator)
actual = assembler.assemble()
# NOTE(review): this snippet is truncated. The outer ast.BinNumExpr(...)
# below is never closed and the final assertion is missing; in the full
# test the two weighted terms are presumably combined with an ADD operator
# and compared against `actual` — confirm against the complete file.
expected = ast.BinNumExpr(
# One weighted term: a constant 1.0 (wrapped in a subroutine) times 0.5.
ast.BinNumExpr(
ast.SubroutineExpr(
ast.NumVal(1.0)),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
# The other weighted term: a comparison feature[0] <= 1.5 choosing between
# 1.0 and 2.0 (wrapped in a subroutine), times 0.5.
ast.BinNumExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(1.5),
ast.CompOpType.LTE),
ast.NumVal(1.0),
ast.NumVal(2.0))),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
def kernel_ast(sup_vec_value):
    """Return the expected AST for the term ``sup_vec_value * feature[0]``.

    The product is wrapped in an ``ast.SubroutineExpr``.

    :param sup_vec_value: numeric value placed in the ``ast.NumVal`` operand.
    """
    product = ast.BinNumExpr(
        ast.NumVal(sup_vec_value),
        ast.FeatureRef(0),
        ast.BinNumOpType.MUL,
    )
    return ast.SubroutineExpr(product)
# NOTE(review): orphaned fragment of an `expected = ...` expression from a
# different test. Its opening (and the enclosing test definition) is not
# present in this snippet, so the trailing ")))" closes constructs that are
# not visible here — confirm against the complete file.
# A constant 3-element vector [1.0, 0.0, 0.0] (presumably a one-hot class
# output — confirm).
ast.VectorVal([
ast.NumVal(1.0),
ast.NumVal(0.0),
ast.NumVal(0.0)]),
# A comparison feature[0] <= 2.5 choosing between the vectors
# [0.0, 1.0, 0.0] and [0.0, 0.0, 1.0].
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(2.5),
ast.CompOpType.LTE),
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(1.0),
ast.NumVal(0.0)]),
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(0.0),
ast.NumVal(1.0)])))
# Check the assembled `actual` expression against `expected` via
# utils.cmp_exprs.
assert utils.cmp_exprs(actual, expected)
def _assemble_branch(self, node_id):
    """Build an ``ast.IfExpr`` for the internal tree node ``node_id``.

    The node's split condition comes from ``self._assemble_cond``. The first
    branch is the recursively assembled left child
    (``self._tree.children_left``). The second branch is the recursively
    assembled right child (``self._tree.children_right``).
    """
    tree = self._tree
    left_child = tree.children_left[node_id]
    right_child = tree.children_right[node_id]
    condition = self._assemble_cond(node_id)
    return ast.IfExpr(
        condition,
        self._assemble_node(left_child),
        self._assemble_node(right_child),
    )
def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs):
    """Track the nesting depth of binary expressions.

    Returns a ``(result, kwargs)`` pair. ``result`` is ``None`` unless
    ``expr`` is an ``ast.BinExpr`` whose ``bin_depth`` equals
    ``self.bin_depth_threshold``. In that case ``result`` is whatever
    ``self.bin_depth_threshold_hook(expr, **kwargs)`` returns.
    ``bin_depth`` itself is never part of the incoming ``kwargs``.
    It is added back, incremented, only for a binary expression below
    the threshold.
    """
    if isinstance(expr, ast.BinExpr):
        if bin_depth == self.bin_depth_threshold:
            # Threshold reached: let the hook handle this expression.
            return self.bin_depth_threshold_hook(expr, **kwargs), kwargs
        # Below the threshold: forward an incremented depth.
        kwargs["bin_depth"] = bin_depth + 1
    return None, kwargs
"rbf": self._rbf_kernel,
"sigmoid": self._sigmoid_kernel,
"poly": self._poly_kernel,
"linear": self._linear_kernel
}
kernel_type = model.kernel
if kernel_type not in supported_kernels:
raise ValueError("Unsupported kernel type {}".format(kernel_type))
self._kernel_fun = supported_kernels[kernel_type]
n_features = len(model.support_vectors_[0])
gamma = model.gamma
if gamma == "auto" or gamma == "auto_deprecated":
gamma = 1.0 / n_features
self._gamma_expr = ast.NumVal(gamma)
self._neg_gamma_expr = utils.sub(ast.NumVal(0), ast.NumVal(gamma),
to_reuse=True)
self._output_size = 1
if type(model).__name__ in ("SVC", "NuSVC"):
n_classes = len(model.n_support_)
if n_classes > 2:
self._output_size = n_classes