How to use the gudhi.SimplexTree class in gudhi

To help you get started, we’ve selected a few gudhi examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: MathieuCarriere / sklearn-tda / sklearn_tda / clustering.py (view on GitHub)
"""
        Compute the extended persistence diagrams of the Mapper simplicial complex associated to each color function.

        Returns:
            list_dgm (list of gudhi persistence diagrams): output extended persistence diagrams. There is one per color function.
        """
        num_cols, list_dgm = self.colors.shape[1], []

        # Compute an extended persistence diagram for each color
        for c in range(num_cols):

            # Retrieve all color values
            col_vals = {node_name: self.node_info_[node_name]["colors"][c] for node_name in self.node_info_.keys()}
            
            # Create a new simplicial complex by coning the Mapper with an extra point with name -2
            st = gd.SimplexTree()
            list_simplices, list_vertices = self.mapper_.get_skeleton(1), self.mapper_.get_skeleton(0)
            for (simplex, f) in list_simplices:
                st.insert(simplex + [-2], filtration=-3)

            # Assign ascending filtration values on the original simplices and descending filtration values on the coned simplices 
            min_val, max_val = min(col_vals), max(col_vals)
            for (vertex, f) in list_vertices:
                if st.find(vertex):
                    st.assign_filtration(vertex,        filtration = -2 + (col_vals[vertex[0]]-min_val)/(max_val-min_val))
                    st.assign_filtration(vertex + [-2], filtration =  2 - (col_vals[vertex[0]]-min_val)/(max_val-min_val))

            # Compute persistence
            st.make_filtration_non_decreasing()
            dgm = st.persistence()

            # Output extended persistence diagrams
GitHub: GUDHI / gudhi-devel / src / python / example / plot_simplex_tree_dim012.py (view on GitHub)
#!/usr/bin/env python
import numpy as np
import gudhi

# Coordinates of the points: row i of this 7x3 array is point i in 3D,
# so vertex labels in the complex below are row indices into `points`.
points=np.array([[0,0,0],[1,0,0],[0,1,0],[0,0,1],[1,1,1],[1,1,0],[0,1,1]])
# Build the simplicial complex with a tetrahedron, an edge and an isolated vertex
cplx=gudhi.SimplexTree()
cplx.insert([1,2,3,5])  # tetrahedron on points 1, 2, 3, 5
cplx.insert([4,6])      # edge between points 4 and 6
cplx.insert([0])        # isolated vertex 0
# List of triangles (point indices).
# get_skeleton(k) yields (simplex, filtration) pairs; s[0] is the simplex's
# vertex list, and keeping only length-3 lists selects the triangles.
triangles = np.array([s[0] for s in cplx.get_skeleton(2) if len(s[0])==3])
# List of edges (point coordinates): each entry is the 2x3 array holding
# the coordinates of the edge's two endpoints, taken from `points`.
edges = []
for s in cplx.get_skeleton(1):
    e = s[0]
    # The 1-skeleton also yields vertices (length 1); keep only edges.
    if len(e) == 2:
        edges.append(points[[e[0],e[1]]])

## With plotly
import plotly.graph_objects as go
# Plot triangles
f2 = go.Mesh3d(
GitHub: MathieuCarriere / perslay / perslay / utils.py (view on GitHub)
def apply_graph_extended_persistence(A, filtration_val, basesimplex):
    num_vertices = A.shape[0]
    (xs, ys) = np.where(np.triu(A))
    num_edges = len(xs)

    if len(filtration_val.shape) == 1:
        min_val, max_val = filtration_val.min(), filtration_val.max()
    else:
        min_val = min([filtration_val[xs[i], ys[i]] for i in range(num_edges)])
        max_val = max([filtration_val[xs[i], ys[i]] for i in range(num_edges)])

    st = gd.SimplexTree()
    st.set_dimension(2)

    for simplex, filt in basesimplex:
        st.insert(simplex=simplex + [-2], filtration=-3)

    if len(filtration_val.shape) == 1:
        if max_val == min_val:
            fa = -.5 * np.ones(filtration_val.shape)
            fd = .5 * np.ones(filtration_val.shape)
        else:
            fa = -2 + (filtration_val - min_val) / (max_val - min_val)
            fd = 2 - (filtration_val - min_val) / (max_val - min_val)
        for vid in range(num_vertices):
            st.assign_filtration(simplex=[vid], filtration=fa[vid])
            st.assign_filtration(simplex=[vid, -2], filtration=fd[vid])
    else:
GitHub: MathieuCarriere / sklearn-tda / sklearn_tda / code.py (view on GitHub)
print("clusters in preimage " + str(preimage) + " = " + str(clusters))

            num_clus_pre = np.max(clusters) + 1
            for i in range(num_clus_pre):
                subpopulation = idxs[clusters == i]
                color_val = np.mean(self.color[subpopulation])
                clus_color[clus_base + i] = color_val
                clus_size[clus_base + i] = len(subpopulation)

            for i in range(clusters.shape[0]):
                if clusters[i] != -1:
                    cover[idxs[i]].append(clus_base + clusters[i])

            clus_base += np.max(clusters) + 1

        self.st_ = gd.SimplexTree()
        for i in range(num_pts):
            num_clus_i = len(cover[i])
            for j in range(num_clus_i):
                self.st_.insert([cover[i][j]])
            self.st_.insert(cover[i])

        self.graph_ = []
        for simplex in self.st_.get_skeleton(2):
            print(simplex)
            if len(simplex[0]) > 1:
                idx1, idx2 = simplex[0][0], simplex[0][1]
                if self.mask <= idx1 and self.mask <= idx2:
                    self.graph_.append([simplex[0]])
            else:
                clus_idx = simplex[0][0]
                if self.mask <= clus_size[clus_idx]:
GitHub: MathieuCarriere / perslay / expe / utils.py (view on GitHub)
def _get_base_simplex(A):
    """Build the base simplicial complex (vertices and edges) of a graph.

    Every vertex ``i`` and every edge ``(i, j)`` with ``i < j`` and
    ``A[i, j] > 0`` is inserted into a GUDHI simplex tree with filtration
    value ``-1e10``. Only the strict upper triangle of ``A`` is read.

    Parameters:
        A: adjacency matrix of the graph, indexed as ``A[i, j]``; the number
            of vertices is taken from ``A.shape[0]`` (presumably a square
            numpy array -- confirm against callers).

    Returns:
        list of ``(simplex, filtration)`` pairs, in the order produced by
        ``SimplexTree.get_filtration()``.
    """
    num_vertices = A.shape[0]
    st = gd.SimplexTree()
    for i in range(num_vertices):
        st.insert([i], filtration=-1e10)
        for j in range(i + 1, num_vertices):
            if A[i, j] > 0:
                st.insert([i, j], filtration=-1e10)
    # Newer GUDHI releases return a one-shot generator from get_filtration().
    # Materialize it so the base simplex can be iterated more than once
    # (e.g. one extended-persistence computation per filtration function).
    return list(st.get_filtration())
GitHub: MathieuCarriere / perslay / perslay / utils.py (view on GitHub)
def get_base_simplex(A):
    """Build the base simplicial complex (vertices and edges) of a graph.

    Every vertex ``i`` and every edge ``(i, j)`` with ``i < j`` and
    ``A[i, j] > 0`` is inserted into a GUDHI simplex tree with filtration
    value ``-1e10``. Only the strict upper triangle of ``A`` is read.

    Parameters:
        A: adjacency matrix of the graph, indexed as ``A[i, j]``; the number
            of vertices is taken from ``A.shape[0]`` (presumably a square
            numpy array -- confirm against callers).

    Returns:
        list of ``(simplex, filtration)`` pairs, in the order produced by
        ``SimplexTree.get_filtration()``.
    """
    num_vertices = A.shape[0]
    st = gd.SimplexTree()
    for i in range(num_vertices):
        st.insert([i], filtration=-1e10)
        for j in range(i + 1, num_vertices):
            if A[i, j] > 0:
                st.insert([i, j], filtration=-1e10)
    # Newer GUDHI releases return a one-shot generator from get_filtration().
    # Materialize it so the base simplex can be iterated more than once
    # (e.g. one extended-persistence computation per filtration function).
    return list(st.get_filtration())
GitHub: MathieuCarriere / sklearn-tda / sklearn_tda / clustering.py (view on GitHub)
num_pts, num_filters, num_colors = self.filters.shape[0], self.filters.shape[1], self.colors.shape[1]

        # If some resolutions are not specified, automatically compute them
        if np.any(np.isnan(self.resolutions)):
            delta, resolutions = self.get_optimal_parameters_for_agglomerative_clustering(X=X, beta=0., C=10, N=100)
            if self.input == "point cloud":
                self.clustering = AgglomerativeClustering(n_clusters=None, linkage="single", distance_threshold=delta, affinity="euclidean")  
            else:
                self.clustering = AgglomerativeClustering(n_clusters=None, linkage="single", distance_threshold=delta, affinity="precomputed")
            self.resolutions = np.where(np.isnan(self.resolutions), resolutions, self.resolutions)

        # If some filter limits are unspecified, automatically compute them
        self.filter_bnds = np.where(np.isnan(self.filter_bnds), np.hstack([np.min(self.filters, axis=0)[:,np.newaxis], np.max(self.filters, axis=0)[:,np.newaxis]]), self.filter_bnds)

        # Initialize attributes
        self.mapper_, self.node_info_ = gd.SimplexTree(), {}

        if np.all(self.gains < .5):
            
            # Compute which points fall in which patch or patch intersections
            interval_inds, intersec_inds = np.empty(self.filters.shape), np.empty(self.filters.shape)
            for i in range(num_filters):
                f, r, g = self.filters[:,i], self.resolutions[i], self.gains[i]
                min_f, max_f = self.filter_bnds[i,0], np.nextafter(self.filter_bnds[i,1], np.inf)
                interval_endpoints, l = np.linspace(min_f, max_f, num=r+1, retstep=True)
                intersec_endpoints = []
                for j in range(1, len(interval_endpoints)-1):
                    intersec_endpoints.append(interval_endpoints[j] - g*l / (2 - 2*g))
                    intersec_endpoints.append(interval_endpoints[j] + g*l / (2 - 2*g))
                interval_inds[:,i] = np.digitize(f, interval_endpoints)
                intersec_inds[:,i] = 0.5 * (np.digitize(f, intersec_endpoints) + 1)
GitHub: MathieuCarriere / perslay / expe / utils.py (view on GitHub)
def _apply_graph_extended_persistence(A, filtration_val, basesimplex):
    num_vertices = A.shape[0]
    (xs, ys) = np.where(np.triu(A))
    num_edges = len(xs)

    if len(filtration_val.shape) == 1:
        min_val, max_val = filtration_val.min(), filtration_val.max()
    else:
        min_val = min([filtration_val[xs[i], ys[i]] for i in range(num_edges)])
        max_val = max([filtration_val[xs[i], ys[i]] for i in range(num_edges)])

    st = gd.SimplexTree()
    st.set_dimension(2)

    for simplex, filt in basesimplex:
        st.insert(simplex=simplex + [-2], filtration=-3)

    if len(filtration_val.shape) == 1:
        if max_val == min_val:
            fa = -.5 * np.ones(filtration_val.shape)
            fd = .5 * np.ones(filtration_val.shape)
        else:
            fa = -2 + (filtration_val - min_val) / (max_val - min_val)
            fd = 2 - (filtration_val - min_val) / (max_val - min_val)
        for vid in range(num_vertices):
            st.assign_filtration(simplex=[vid], filtration=fa[vid])
            st.assign_filtration(simplex=[vid, -2], filtration=fd[vid])
    else:
GitHub: GUDHI / gudhi-devel / src / cython / example / simplex_tree_example.py (view on GitHub)
but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see .
"""

__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2016 Inria"
__license__ = "GPL v3"

# Section banner for this part of the example.
print("#####################################################################")
print("SimplexTree creation from insertion")

# Start from an empty simplex tree.
st = gudhi.SimplexTree()

# The example treats a truthy insert() result as "the edge was added".
print("Inserted !!" if st.insert([0, 1]) else "Not inserted...")

# The example treats a truthy find() result as "the edge is present".
print("Found !!" if st.find([0, 1]) else "Not found...")

# Insert the triangle [0, 1, 2] with an explicit filtration value of 4.0 and
# report whether insert() signalled (via a truthy result) that it was added.
if st.insert([0, 1, 2], filtration=4.0):
    print("Inserted !!")
else:
    print("Not inserted...")