How to use the geomstats.backend.squeeze function in geomstats

To help you get started, we’ve selected a few geomstats examples that show popular ways the geomstats.backend.squeeze function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: geomstats/geomstats, file geomstats/special_orthogonal_group.py (view on GitHub)
rot_mat = self.projection(rot_mat)

        if self.n == 3:
            trace = gs.trace(rot_mat, axis1=1, axis2=2)
            trace = gs.to_ndarray(trace, to_ndim=2, axis=1)
            assert trace.shape == (n_rot_mats, 1), trace.shape

            cos_angle = .5 * (trace - 1)
            cos_angle = gs.clip(cos_angle, -1, 1)
            angle = gs.arccos(cos_angle)

            rot_mat_transpose = gs.transpose(rot_mat, axes=(0, 2, 1))
            rot_vec = self.vector_from_skew_matrix(rot_mat - rot_mat_transpose)

            mask_0 = gs.isclose(angle, 0.)
            mask_0 = gs.squeeze(mask_0, axis=1)
            rot_vec[mask_0] = (rot_vec[mask_0]
                               * (.5 - (trace[mask_0] - 3.) / 12.))

            mask_pi = gs.isclose(angle, gs.pi)
            mask_pi = gs.squeeze(mask_pi, axis=1)

            # choose the largest diagonal element
            # to avoid a square root of a negative number
            a = gs.array(0)
            if gs.any(mask_pi):
                a = gs.argmax(gs.diagonal(rot_mat[mask_pi], axis1=1, axis2=2))
            b = (a + 1) % 3
            c = (a + 2) % 3

            # compute the axis vector
            sq_root = gs.sqrt((rot_mat[mask_pi, a, a]
Source: geomstats/geomstats, file geomstats/geometry/poincare_ball.py (view on GitHub)
grad_term_211 = \
            gs.exp((prod_alpha_sigma) ** 2) \
            * (1 + gs.erf(prod_alpha_sigma)) \
            * gs.einsum('ij,j->ij', sigma_repeated, alpha ** 2) * 2

        grad_term_212 = gs.repeat(gs.expand_dims((2 / gs.sqrt(gs.pi))
                                                 * alpha, axis=0),
                                  variances.shape[0], axis=0)

        grad_term_22 = grad_term_211 + grad_term_212
        grad_term_22 = gs.einsum('ij, j->ij', grad_term_22, beta)
        grad_term_22 = gs.sum(grad_term_22, axis=-1, keepdims=True)

        norm_factor_gradient = grad_term_1 + (grad_term_21 * grad_term_22)

        return gs.squeeze(norm_factor), gs.squeeze(norm_factor_gradient)
Source: geomstats/geomstats, file geomstats/special_orthogonal_group.py (view on GitHub)
coef_1 += mask_else_float * (angle / 2) / gs.tan(angle / 2)
                coef_2 += mask_else_float * (1 - coef_1) / angle ** 2

                jacobian = gs.zeros((n_points, self.dimension, self.dimension))
                for i in range(n_points):
                    mask_i_float = get_mask_i_float(i, n_points)

                    sign = - 1
                    if left_or_right == 'left':
                        sign = + 1

                    jacobian_i = (
                        coef_1[i] * gs.eye(self.dimension)
                        + coef_2[i] * gs.outer(point[i], point[i])
                        + sign * self.skew_matrix_from_vector(point[i]) / 2)
                    jacobian_i = gs.squeeze(jacobian_i, axis=0)

                    jacobian += gs.einsum(
                        'n,ij->nij', mask_i_float, jacobian_i)

            else:
                if left_or_right == 'right':
                    raise NotImplementedError(
                        'The jacobian of the right translation'
                        ' is not implemented.')
                jacobian = self.matrix_from_rotation_vector(point)

            assert gs.ndim(jacobian) == 3

        elif point_type == 'matrix':
            raise NotImplementedError()
Source: geomstats/geomstats, file geomstats/geometry/hypersphere.py (view on GitHub)
raise NotImplementedError(
                'The Christoffel symbols are only implemented'
                ' for spherical coordinates in the 2-sphere')

        point = gs.to_ndarray(point, to_ndim=2)
        christoffel = []
        for sample in point:
            gamma_0 = gs.array(
                [[0, 0], [0, - gs.sin(sample[0]) * gs.cos(sample[0])]])
            gamma_1 = gs.array([[0, gs.cos(sample[0]) / gs.sin(sample[0])],
                                [gs.cos(sample[0]) / gs.sin(sample[0]), 0]])
            christoffel.append(gs.stack([gamma_0, gamma_1]))

        christoffel = gs.stack(christoffel)
        if gs.ndim(christoffel) == 4 and gs.shape(christoffel)[0] == 1:
            christoffel = gs.squeeze(christoffel, axis=0)
        return christoffel
Source: geomstats/geomstats, file geomstats/geometry/poincare_ball.py (view on GitHub)
point_a_norm = gs.clip(gs.sum(
                point_a_flatten ** 2, -1), 0., 1 - EPSILON)
            point_b_norm = gs.clip(gs.sum(
                point_b_flatten ** 2, -1), 0., 1 - EPSILON)

            square_diff = (point_a_flatten - point_b_flatten) ** 2

            diff_norm = gs.sum(square_diff, -1)
            norm_function = 1 + 2 * \
                diff_norm / ((1 - point_a_norm) * (1 - point_b_norm))

            dist = gs.log(norm_function + gs.sqrt(norm_function ** 2 - 1))
            dist *= self.scale
            dist = gs.reshape(dist, (point_a.shape[0], point_b.shape[0]))
            dist = gs.squeeze(dist)

        elif point_a.shape == point_b.shape:
            dist = self.dist(point_a, point_b)

        return dist
Source: geomstats/geomstats, file geomstats/geometry/special_orthogonal3.py (view on GitHub)
mask_0_float = gs.cast(mask_0, gs.float32) + self.epsilon

        coef_1 += mask_0_float * (1. - (angle ** 2) / 6.)
        coef_2 += mask_0_float * (1. / 2. - angle ** 2)

        # This avoids dividing by 0.
        mask_else = ~mask_0
        mask_else_float = gs.cast(mask_else, gs.float32) + self.epsilon

        angle += mask_0_float

        coef_1 += mask_else_float * (gs.sin(angle) / angle)
        coef_2 += mask_else_float * (
            (1. - gs.cos(angle)) / (angle ** 2))

        coef_1 = gs.squeeze(coef_1, axis=1)
        coef_2 = gs.squeeze(coef_2, axis=1)
        term_1 = (gs.eye(self.dim)
                  + gs.einsum('n,njk->njk', coef_1, skew_rot_vec))

        squared_skew_rot_vec = gs.einsum(
            'nij,njk->nik', skew_rot_vec, skew_rot_vec)

        term_2 = gs.einsum('n,njk->njk', coef_2, squared_skew_rot_vec)

        return term_1 + term_2
Source: geomstats/geomstats, file geomstats/geometry/hyperbolic.py (view on GitHub)
bound: float, optional
            Bound defining the hypersquare in which to sample uniformly.

        Returns
        -------
        samples : array-like, shape=[..., dim + 1]
            Samples in hyperbolic space.
        """
        size = (n_samples, self.dim)
        samples = bound * 2. * (gs.random.rand(*size) - 0.5)

        samples = Hyperbolic.change_coordinates_system(
            samples, 'intrinsic', self.coords_type)

        if n_samples == 1:
            samples = gs.squeeze(samples, axis=0)
        return samples