Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"name": transformer.prediction_column,
"port": "prediction"
})
outputs.append({
"name": "raw_prediction",
"port": "raw_prediction"
})
outputs.append({
"name": "probability",
"port": "probability"
})
# compile tuples of model attributes to serialize
tree_weights = Vector([1.0 for x in range(0, len(transformer.estimators_))])
attributes = list()
attributes.append(('num_features', transformer.n_features_))
attributes.append(('tree_weights', tree_weights))
attributes.append(('trees', ["tree{}".format(x) for x in range(0, len(transformer.estimators_))]))
if isinstance(transformer, RandomForestClassifier):
attributes.append(('num_classes', transformer.n_classes_)) # TODO: get number of classes from the transformer
self.serialize(transformer, path, model, attributes, inputs, outputs)
rf_path = "{}/{}.node".format(path, model)
estimators = transformer.estimators_
i = 0
for estimator in estimators:
estimator.mlinit(input_features = transformer.input_features, prediction_column = transformer.prediction_column, feature_names=transformer.feature_names)
def serialize_to_bundle(self, path, model_name):
    """Serialize this label-mapping transformer to a bundle node.

    The mapping in ``self.labels`` is written as two parallel attributes:
    ``"labels"`` (the keys, as a list) and ``"values"`` (the values, as a list
    wrapped in a ``Vector``). The node gets one input port, ``"input"``, bound
    to the first entry of ``self.input_features``, and one output port,
    ``"output"``, bound to ``self.output_features``.

    :param path: destination passed through unchanged to ``self.serialize``.
    :param model_name: node name passed through unchanged to ``self.serialize``.
    """
    # Read the mapping once so keys and values are guaranteed to stay aligned.
    # Materialize them as lists: in Python 3, ``keys()``/``values()`` return
    # live view objects, which are not lists and which ``json`` cannot encode.
    label_items = list(self.labels.items())
    label_keys = [key for key, _ in label_items]
    label_values = [value for _, value in label_items]

    # compile tuples of model attributes to serialize
    attributes = [
        ("labels", label_keys),
        ("values", Vector(label_values)),
    ]

    # define node inputs and outputs
    inputs = [{
        "name": self.input_features[0],
        "port": "input",
    }]
    outputs = [{
        "name": self.output_features,
        "port": "output",
    }]

    # ``serialize`` takes the transformer as its first argument; here the
    # transformer being written out is this object itself.
    self.serialize(self, path, model_name, attributes, inputs, outputs)