How to use the pgmpy.utils.pinverse function in pgmpy

To help you get started, we've selected a few pgmpy examples based on popular ways pgmpy.utils.pinverse is used in public projects.
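
Before looking at the project code, here is a minimal sketch of what the function does. It assumes that pinverse can be imported from pgmpy.utils (as the title of this page suggests) and that it returns the Moore-Penrose pseudo-inverse of a 2-D torch tensor, i.e. that it behaves like torch.pinverse; both assumptions should be checked against the pgmpy version you have installed.

import torch
from pgmpy.utils import pinverse  # assumed import path, per the title of this page

# A singular matrix: the second row is a multiple of the first, so a regular
# inverse does not exist, but the Moore-Penrose pseudo-inverse does.
A = torch.tensor([[1.0, 2.0],
                  [2.0, 4.0]])

A_pinv = pinverse(A)

# Moore-Penrose condition: A @ A_pinv @ A reproduces A.
print(torch.allclose(A @ A_pinv @ A, A, atol=1e-6))

# Assumption: pgmpy's pinverse agrees with torch.pinverse.
print(torch.allclose(A_pinv, torch.pinverse(A), atol=1e-6))

The examples below, all taken from pgmpy's SEMEstimator, use pinverse in exactly this role: inverting matrices that may be singular or ill-conditioned during optimization.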

Example from pgmpy/pgmpy, pgmpy/estimators/SEMEstimator.py (view on GitHub):
def _get_implied_cov(self, B, zeta):
    """
    Computes the implied covariance matrix from the given parameters.
    """
    B_masked = torch.mul(B, self.B_mask) + self.B_fixed_mask
    B_inv = pinverse(self.B_eye - B_masked)
    zeta_masked = torch.mul(zeta, self.zeta_mask) + self.zeta_fixed_mask

    return self.wedge_y @ B_inv @ zeta_masked @ B_inv.t() @ self.wedge_y.t()
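
The method above inverts (I - B) with pinverse rather than a plain matrix inverse, so the computation still goes through when I - B is singular or ill-conditioned. Below is a standalone sketch of the same computation with toy tensors; the shapes and variable names are illustrative, not pgmpy's API.

import torch
from pgmpy.utils import pinverse  # assumed import; torch.pinverse plays the same role

m = 4  # total number of variables in the structural model (illustrative)
p = 2  # number of observed y variables (illustrative)

B = torch.rand(m, m) * 0.1   # structural coefficients (toy values)
zeta = torch.eye(m)          # error covariance (toy values)
wedge_y = torch.zeros(p, m)  # selection matrix picking out the observed variables
wedge_y[0, 0] = 1.0
wedge_y[1, 1] = 1.0

B_inv = pinverse(torch.eye(m) - B)  # pseudo-inverse of (I - B)
sigma = wedge_y @ B_inv @ zeta @ B_inv.t() @ wedge_y.t()

print(sigma.shape)  # (p, p): implied covariance of the observed variables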

Example from pgmpy/pgmpy, pgmpy/estimators/SEMEstimator.py (view on GitHub):
Parameters
        ----------
        params: dict
            params contains all the variables which are updated in each iteration of the
            optimization.

        loss_args: dict
            loss_args contains all the variables which are not updated in each iteration but
            are required to compute the loss.

        Returns
        -------
        torch.tensor: The loss value for the given params and loss_args
        """
        S = loss_args["S"]
        W_inv = pinverse(loss_args["W"])
        sigma = self._get_implied_cov(params["B"], params["zeta"])
        return ((S - sigma) @ W_inv).pow(2).trace()
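
In this loss, pinverse is applied to the weight matrix W, so the weighted least-squares style expression trace(((S - sigma) @ W^-1)^2) can still be evaluated when W is not invertible. A toy sketch of the same expression, with illustrative names and values:

import torch
from pgmpy.utils import pinverse  # assumed import; torch.pinverse would behave the same

p = 3
S = torch.eye(p) * 2.0      # sample covariance (toy values)
sigma = torch.eye(p) * 1.5  # model-implied covariance (toy values)
W = torch.eye(p)            # weight matrix (toy values)

W_inv = pinverse(W)
loss = ((S - sigma) @ W_inv).pow(2).trace()
print(loss)  # a scalar tensor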

Example from pgmpy/pgmpy, pgmpy/estimators/SEMEstimator.py (view on GitHub):
optimization.

        loss_args: dict
            loss_args contains all the variables which are not updated in each iteration but
            are required to compute the loss.

        Returns
        -------
        torch.tensor: The loss value for the given params and loss_args
        """
        S = loss_args["S"]
        sigma = self._get_implied_cov(params["B"], params["zeta"])

        return (
            sigma.det().clamp(min=1e-4).log()
            + (S @ pinverse(sigma)).trace()
            - S.logdet()
            - len(self.model.y)
        )
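
Here pinverse(sigma) stands in for the inverse of the implied covariance in the standard maximum-likelihood fitting function log|sigma| + tr(S @ sigma^-1) - log|S| - p, with the determinant clamped away from zero before taking the log. A toy sketch of the same expression, with illustrative names and values:

import torch
from pgmpy.utils import pinverse  # assumed import; torch.pinverse would behave the same

p = 3
S = torch.eye(p) * 2.0      # sample covariance (toy values)
sigma = torch.eye(p) * 1.8  # model-implied covariance (toy values)

ml_loss = (
    sigma.det().clamp(min=1e-4).log()  # log|sigma|, clamped away from zero
    + (S @ pinverse(sigma)).trace()    # tr(S @ sigma^-1), via the pseudo-inverse
    - S.logdet()                       # log|S|
    - p                                # number of observed variables
)
print(ml_loss)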