How to use the gala.features function in gala

To help you get started, we’ve selected a few gala examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github janelia-flyem / gala / tests / example-data / example.py View on Github external
# imports
from gala import imio, classify, features, agglo, evaluate as ev

# Load the three training stacks (file names: train-gt, train-p1, train-ws).
gt_train, pr_train, ws_train = [
    imio.read_h5_stack(fn)
    for fn in ('train-gt.lzf.h5', 'train-p1.lzf.h5', 'train-ws.lzf.h5')
]

# Combine the moments and histogram feature managers into one composite.
fm = features.moments.Manager()
fh = features.histogram.Manager()
fc = features.base.Composite(children=[fm, fh])

# Build the training graph and keep the first element returned by
# learn_agglomerate.
g_train = agglo.Rag(ws_train, pr_train, feature_manager=fc)
X, y, w, merges = g_train.learn_agglomerate(gt_train, fc)[0]
# gala has 3 truth labeling schemes, pick the first one
y = y[:, 0]
# X and y are now in the standard scikit-learn input format
print((X.shape, y.shape))

# Fit a classifier using scikit-learn syntax.
rf = classify.DefaultRandomForest().fit(X, y)
# A policy is the composition of a feature map and a classifier.
learned_policy = agglo.classifier_probability(fc, rf)
# get the test data and make a RAG with the trained policy
pr_test, ws_test = (map(imio.read_h5_stack,
github janelia-flyem / gala / tests / test_features.py View on Github external
"""
    if type(a1) == list and type(a2) == list:
        [assert_equal_lists_or_arrays(i1, i2, eps) for i1, i2 in zip(a1,a2)]
    elif type(a1) == np.ndarray and type(a2) == np.ndarray:
        assert_allclose(a1, a2, atol=eps)
    elif type(a1) == float and type(a2) == float:
        assert_approx_equal(a1, a2, int(-np.log10(eps)))
    else:
        assert_equal(a1, a2)


# Module-level fixtures: probabilities come from a .npy file, the watershed
# from a text file read as uint32. Paths are resolved against rundir.
probs2 = np.load(os.path.join(rundir, 'toy-data/test-04-probabilities.npy'))
probs1 = probs2[..., 0]  # index 0 along the last axis of probs2
wss1 = np.loadtxt(
    os.path.join(rundir, 'toy-data/test-04-watershed.txt'), np.uint32)

# Three individual feature managers, then a composite wrapping all of them.
f1 = features.moments.Manager(2, False)
f2 = features.histogram.Manager(3, compute_percentiles=[0.5])
f3 = features.squiggliness.Manager(ndim=2)
f4 = features.base.Composite(children=[f1, f2, f3])


def run_matched(f, fn, c=1, edges=None, merges=None):
    """Compare features computed on the toy RAG against a pickled reference.

    Builds ``agglo.Rag(wss1, p, feature_manager=f, use_slow=True)``, collects
    the output of ``list_of_feature_arrays`` and asserts it matches the
    object unpickled from ``fn``.

    Parameters
    ----------
    f : feature manager
        Passed to both ``agglo.Rag`` and ``list_of_feature_arrays``.
    fn : str
        Path of the pickle file holding the reference output.
    c : int, optional
        ``1`` selects ``probs1``; any other value selects ``probs2``.
    edges : list of tuple, optional
        Defaults to ``[(1, 2), (6, 3), (7, 4)]``.
    merges : list of tuple, optional
        Defaults to ``[(1, 2), (6, 3)]``.

    Raises
    ------
    AssertionError
        Propagated from ``assert_equal_lists_or_arrays`` on mismatch.
    """
    # Build fresh lists per call: list defaults in the signature would be
    # shared between calls and are handed to a helper that could mutate them.
    if edges is None:
        edges = [(1, 2), (6, 3), (7, 4)]
    if merges is None:
        merges = [(1, 2), (6, 3)]
    p = probs1 if c == 1 else probs2
    g = agglo.Rag(wss1, p, feature_manager=f, use_slow=True)
    o = list_of_feature_arrays(g, f, edges, merges)
    # NOTE: unpickling can execute arbitrary code; `fn` must be a trusted
    # fixture file.
    with open(fn, 'rb') as fin:
        r = pck.load(fin, encoding='bytes')
    assert_equal_lists_or_arrays(o, r)

github janelia-flyem / gala / tests / test_gala.py View on Github external
# Directory containing this file; the data paths below are joined onto it.
rundir = os.path.dirname(__file__)

# Load the example data. Each *_list ends up holding the joined paths.
train_list = [
    os.path.join(rundir, fn)
    for fn in ('example-data/train-gt.lzf.h5', 'example-data/train-p1.lzf.h5',
               'example-data/train-p4.lzf.h5', 'example-data/train-ws.lzf.h5')
]
gt_train, pr_train, p4_train, ws_train = [
    imio.read_h5_stack(path) for path in train_list
]
test_list = [
    os.path.join(rundir, fn)
    for fn in ('example-data/test-gt.lzf.h5', 'example-data/test-p1.lzf.h5',
               'example-data/test-p4.lzf.h5', 'example-data/test-ws.lzf.h5')
]
gt_test, pr_test, p4_test, ws_test = [
    imio.read_h5_stack(path) for path in test_list
]

# Composite feature manager built from the moments and histogram managers.
fm = features.moments.Manager()
fh = features.histogram.Manager()
fc = features.base.Composite(children=[fm, fh])

### helper functions


def load_pickle(fn):
    """Unpickle and return the object stored in the file at path `fn`.

    When ``PYTHON_VERSION`` is 3 the pickle is read with ``encoding='bytes'``
    and ``fix_imports=True``; otherwise ``pickle.load`` is called with no
    extra arguments.
    """
    with open(fn, 'rb') as fin:
        options = {}
        if PYTHON_VERSION == 3:
            options = {'encoding': 'bytes', 'fix_imports': True}
        return pickle.load(fin, **options)


def load_training_data(fn):
    io = np.load(fn)
github janelia-flyem / gala / tests / test_server.py View on Github external
def dummy_data2(dummy_data):
    """Relabel two regions of the fragment array and pair it with a mock manager.

    `dummy_data` is unpacked as a 3-tuple whose first two items are used as
    ``frag`` and ``gt``. ``frag`` is modified in place. Returns
    ``(frag, gt, features.base.Mock(frag, gt))``.
    """
    frag, gt, _ = dummy_data
    # Keep this order: the two assignments may touch the same cell.
    frag[7, 7:9] = 17
    frag[7:10, -1] = 18
    return frag, gt, features.base.Mock(frag, gt)
github janelia-flyem / gala / gala / test_package.py View on Github external
def testAggloRFBuild(self):
        """Agglomerate with a stored classifier and check the node counts.

        Loads ``agglomclassifier.rf.h5`` from the ``testdata`` directory next
        to the installed ``gala`` package, rebuilds the feature manager from
        the classifier's feature description, and asserts the number of RAG
        nodes after construction, after ``agglomerate(0.1)`` and after
        ``remove_inclusions()``.
        """
        from gala import agglo, features, classify
        package_dir = os.path.dirname(sys.modules["gala"].__file__)
        self.datadir = os.path.abspath(package_dir) + "/testdata/"

        classifier = classify.load_classifier(
            self.datadir + "agglomclassifier.rf.h5")
        description = json.loads(str(classifier.feature_description))
        manager = features.io.create_fm(description)
        priority = agglo.classifier_probability(manager, classifier)

        # The middle value returned by gen_watershed is not used here.
        watershed, _, prediction = self.gen_watershed()
        rag = agglo.Rag(watershed, prediction, priority,
                        feature_manager=manager, nozeros=True)
        self.assertEqual(rag.number_of_nodes(), 3630)
        rag.agglomerate(0.1)
        self.assertEqual(rag.number_of_nodes(), 88)
        rag.remove_inclusions()
        self.assertEqual(rag.number_of_nodes(), 86)
github janelia-flyem / gala / tests / test_features.py View on Github external
def test_convex_hull():
    """Check the convex-hull feature vector for edge (1, 2) of a 3x3 label image."""
    labels = np.array([[1, 2, 2],
                       [1, 1, 2],
                       [1, 2, 2]], dtype=np.uint8)
    manager = features.convex_hull.Manager()
    rag = agglo.Rag(labels, feature_manager=manager, use_slow=True)
    reference = np.array([0.5, 0.125, 0.5, 0.1, 1., 0.167, 0.025, 0.069,
                          0.44, 0.056, 1.25, 1.5, 1.2, 0.667])
    computed = manager(rag, 1, 2)
    # NOTE(review): rtol=1. allows a deviation of 1 * |reference| on top of
    # atol, which makes this check very permissive; confirm this is intended.
    assert_allclose(computed, reference, atol=0.01, rtol=1.)
github janelia-flyem / gala / gala / test_package.py View on Github external
def testAggloRFBuild(self):
        """Agglomerate with a stored classifier and check the node counts.

        Loads ``agglomclassifier.rf.h5`` from the ``testdata`` directory next
        to the installed ``gala`` package, rebuilds the feature manager from
        the classifier's feature description, and asserts the number of RAG
        nodes after construction, after ``agglomerate(0.1)`` and after
        ``remove_inclusions()``.
        """
        from gala import agglo, features, classify
        package_dir = os.path.dirname(sys.modules["gala"].__file__)
        self.datadir = os.path.abspath(package_dir) + "/testdata/"

        classifier = classify.load_classifier(
            self.datadir + "agglomclassifier.rf.h5")
        description = json.loads(str(classifier.feature_description))
        manager = features.io.create_fm(description)
        priority = agglo.classifier_probability(manager, classifier)

        # The middle value returned by gen_watershed is not used here.
        watershed, _, prediction = self.gen_watershed()
        rag = agglo.Rag(watershed, prediction, priority,
                        feature_manager=manager, nozeros=True)
        self.assertEqual(rag.number_of_nodes(), 3630)
        rag.agglomerate(0.1)
        self.assertEqual(rag.number_of_nodes(), 88)
        rag.remove_inclusions()
        self.assertEqual(rag.number_of_nodes(), 86)
github janelia-flyem / gala / test / test_gala.py View on Github external
def setUp(self):
        """Load the test-05 stacks and create the feature managers used by the tests.

        Sets ``probs2``/``probs1`` (``probs1`` is index 0 along the last axis
        of ``probs2``), ``wss1``, the three managers ``f1``..``f3`` and the
        composite ``f4`` that wraps them.
        """
        self.probs2 = imio.read_h5_stack(rundir + '/test-05-probabilities.h5')
        self.probs1 = self.probs2[..., 0]
        self.wss1 = imio.read_h5_stack(rundir + '/test-05-watershed.h5')
        self.f1 = features.moments.Manager(2, False)
        self.f2 = features.histogram.Manager(3, compute_percentiles=[0.5])
        self.f3 = features.squiggliness.Manager(ndim=2)
        self.f4 = features.base.Composite(
            children=[self.f1, self.f2, self.f3])
github janelia-flyem / gala / gala / agglo.py View on Github external
def __init__(self, watershed=array([], label_dtype),
                 probabilities=array([]),
                 merge_priority_function=boundary_mean, gt_vol=None,
                 feature_manager=features.base.Null(), mask=None,
                 show_progress=False, connectivity=1,
                 channel_is_oriented=None, orientation_map=array([]),
                 normalize_probabilities=False, exclusions=array([]),
                 isfrozennode=None, isfrozenedge=None, use_slow=False,
                 update_unchanged_edges=False):
        """Initialise the Rag from a watershed and a probability array.

        The statements visible here store ``show_progress``, ``connectivity``
        and ``merge_priority_function`` on the instance, pick a progress bar,
        and forward the watershed, probabilities and orientation arguments to
        the corresponding ``set_*`` methods.

        NOTE(review): the remaining parameters (``gt_vol``,
        ``feature_manager``, ``mask``, ``exclusions``, ``isfrozennode``,
        ``isfrozenedge``, ``use_slow``, ``update_unchanged_edges``) are not
        used in the portion shown; presumably the constructor continues
        beyond this excerpt -- confirm against the full source.

        NOTE(review): the ``array(...)`` and ``features.base.Null()``
        defaults are evaluated once, when the function is defined, and the
        same objects are shared by every call that relies on them.
        """

        super(Rag, self).__init__(weighted=False)
        self.show_progress = show_progress
        self.connectivity = connectivity
        # Use a real progress bar only when progress display was requested.
        self.pbar = (ip.StandardProgressBar() if self.show_progress
                     else ip.NoProgressBar())
        self.set_watershed(watershed, connectivity)
        self.set_probabilities(probabilities, normalize_probabilities)
        self.set_orientations(orientation_map, channel_is_oriented)
        self.merge_priority_function = merge_priority_function
        # Start at -inf; presumably raised as merges are scored -- TODO confirm.
        self.max_merge_score = -inf
github janelia-flyem / gala / benchmarks / bench_gala.py View on Github external
# Absolute path of '../tests/example-data' resolved against rundir; the
# data-loading helpers in this file read their input files from here.
dd = os.path.abspath(os.path.join(rundir, '../tests/example-data'))


from time import process_time


@contextmanager
def timer():
    """Measure the CPU time spent inside a ``with`` block.

    Yields an initially empty list. When the block exits, the elapsed
    ``process_time()`` in seconds is appended as its only element, so the
    caller reads the result as ``t[0]`` after the ``with`` statement.

    The measurement is recorded even when the block raises; the exception
    still propagates to the caller.
    """
    # Named `elapsed` rather than `time` to avoid shadowing the stdlib module.
    elapsed = []
    t0 = process_time()
    try:
        yield elapsed
    finally:
        elapsed.append(process_time() - t0)


# Module-level object returned by features.default.paper_em(); presumably the
# feature manager used by the benchmarks -- TODO confirm.
em = features.default.paper_em()


def trdata():
    """Read the training stacks from the example-data directory `dd`.

    Returns a tuple of the stacks read from ``train-ws.lzf.h5``,
    ``train-p1.lzf.h5`` and ``train-gt.lzf.h5``, in that order.
    """
    names = ('train-ws.lzf.h5', 'train-p1.lzf.h5', 'train-gt.lzf.h5')
    wstr, prtr, gttr = [imio.read_h5_stack(os.path.join(dd, name))
                        for name in names]
    return wstr, prtr, gttr


def tsdata():
    """Read the test stacks from the example-data directory `dd`.

    Returns a tuple of the stacks read from ``test-ws.lzf.h5``,
    ``test-p1.lzf.h5`` and ``test-gt.lzf.h5``, in that order.
    """
    names = ('test-ws.lzf.h5', 'test-p1.lzf.h5', 'test-gt.lzf.h5')
    wsts, prts, gtts = [imio.read_h5_stack(os.path.join(dd, name))
                        for name in names]
    return wsts, prts, gtts