How to use the gala.classify.DefaultRandomForest function in gala

To help you get started, we’ve selected a few gala examples, based on popular ways it is used in public projects.

Secure your code as it's written: use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

GitHub: janelia-flyem/gala — tests/_util/generate-test-results.py (view on GitHub)
# Continuation of the test-data generation script. Names used below but not
# assigned here (``rf``, ``X``, ``y``, ``fc``, ``ws_train``, ``ws_test``,
# ``pr_test``, ``p4_train``, ``p4_test``, ``gt_train``, ``gt_test`` and the
# ``agglo``/``classify``/``imio``/``ev`` modules) come from earlier in the
# file, outside this excerpt.
#
# ``np.random.seed(0)`` reseeds NumPy's global RNG before each stochastic
# step so that regenerated results are reproducible. The previous
# ``np.random.RandomState(0)`` created a new generator object and discarded
# it immediately, which left the global RNG untouched.
np.random.seed(0)
rf = rf.fit(X, y)
classify.save_classifier(rf, 'example-data/rf-1.joblib')
# Segment the test volume using a merge policy driven by the fitted forest;
# 0.5 is presumably the merge-probability threshold — confirm in agglo.Rag.
learned_policy = agglo.classifier_probability(fc, rf)
g_test = agglo.Rag(ws_test, pr_test, learned_policy, feature_manager=fc)
g_test.agglomerate(0.5)
seg_test1 = g_test.get_segmentation()
imio.write_h5_stack(seg_test1, 'example-data/test-seg1.lzf.h5', compression='lzf')

# Repeat with the p4 probability maps: build a training set from
# agglomeration against gt_train, fit a second forest, and segment.
g_train4 = agglo.Rag(ws_train, p4_train, feature_manager=fc)
np.random.seed(0)
(X4, y4, w4, merges4) = map(np.copy, map(np.ascontiguousarray,
                            g_train4.learn_agglomerate(gt_train, fc)[0]))
# Function-call form prints the same on Python 2 and 3; the bare
# ``print X4.shape`` statement is a SyntaxError on Python 3.
print(X4.shape)
# The full y4 array is saved; only its first column is used for fitting.
np.savez('example-data/train-set4.npz', X=X4, y=y4)
y4 = y4[:, 0]
rf4 = classify.DefaultRandomForest()
np.random.seed(0)
rf4 = rf4.fit(X4, y4)
classify.save_classifier(rf4, 'example-data/rf-4.joblib')
learned_policy4 = agglo.classifier_probability(fc, rf4)
g_test4 = agglo.Rag(ws_test, p4_test, learned_policy4, feature_manager=fc)
g_test4.agglomerate(0.5)
seg_test4 = g_test4.get_segmentation()
imio.write_h5_stack(seg_test4, 'example-data/test-seg4.lzf.h5', compression='lzf')

# Split-VI against gt_test for the raw watershed and both learned
# segmentations, stacked row-wise in that order.
results = np.vstack((
    ev.split_vi(ws_test, gt_test),
    ev.split_vi(seg_test1, gt_test),
    ev.split_vi(seg_test4, gt_test)
    ))

np.save('example-data/vi-results.npy', results)
GitHub: janelia-flyem/gala — gala/serve.py (view on GitHub)
def relearn(self):
        """Retrain the merge policy on the data collected so far and reset the RAG.

        A new ``classify.DefaultRandomForest`` is fit on ``self.features``
        and ``self.targets``, and ``agglo.classifier_probability`` wraps it
        into ``self.policy``. Next, ``self.rag`` is replaced by a copy of
        ``self.original_rag`` that uses the new policy as its merge priority
        function, and the copy's merge queue is rebuilt. Finally, each
        recorded separation is re-applied: for the i-th pair in
        ``self.separate``, ``i`` is added to the ``'exclusions'`` entry of
        both nodes in the pair.
        """
        forest = classify.DefaultRandomForest()
        forest = forest.fit(self.features, self.targets)
        self.policy = agglo.classifier_probability(self.feature_manager, forest)

        fresh_rag = self.original_rag.copy()
        self.rag = fresh_rag
        fresh_rag.merge_priority_function = self.policy
        fresh_rag.rebuild_merge_queue()

        # NOTE(review): the intent is to restore both the merges and the
        # separations received so far, but only ``self.separate`` is replayed
        # in this method — confirm whether merges are handled elsewhere.
        for index, (first, second) in enumerate(self.separate):
            for node_id in (first, second):
                fresh_rag.node[node_id]['exclusions'].add(index)
GitHub: janelia-flyem/gala — benchmarks/bench_gala.py (view on GitHub)
def classifier():
    """Fit a ``classify.DefaultRandomForest`` on ``trexamples()`` and return it.

    Returns
    -------
    The same DefaultRandomForest instance after ``fit`` has been called on
    it. The return value of ``fit`` itself is not used.
    """
    features, labels = trexamples()
    forest = classify.DefaultRandomForest()
    forest.fit(features, labels)
    return forest