How to use the `lab.B.diag` function in lab

To help you get started, we’ve selected a few examples of how `B.diag` from lab is used in public projects, based on its most common usage patterns.

Secure your code as it’s written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

Source: GitHub — wesselb/stheno, `tests/` (view on GitHub)
def test_diag():
    """Check `B.diag` on several matrix types (`Dense`, `Diagonal`, `LowRank`,
    `Constant`, `Woodbury`).

    NOTE(review): this snippet is truncated by the scrape — the `np.diag(`
    calls in the `LowRank` section are never closed and the `Woodbury`
    section has no body, so the function is not valid Python as shown.
    """
    # Test `Dense`.
    a = np.random.randn(5, 3)
    allclose(B.diag(Dense(a)), np.diag(a))

    # Test `Diagonal`.
    # Per the expected values below, the optional second argument sets the
    # size: 2 truncates the diagonal, 4 pads it with a trailing zero.
    allclose(B.diag(Diagonal(np.array([1, 2, 3]))), [1, 2, 3])
    allclose(B.diag(Diagonal(np.array([1, 2, 3]), 2)), [1, 2])
    allclose(B.diag(Diagonal(np.array([1, 2, 3]), 4)), [1, 2, 3, 0])

    # Test `LowRank`.
    # NOTE(review): each expected value below is cut off after `np.diag(`;
    # presumably it is the diagonal of the dense `left @ right.T` — TODO
    # restore from upstream.
    b = np.random.randn(10, 3)
    allclose(B.diag(LowRank(left=a, right=a)), np.diag(
    allclose(B.diag(LowRank(left=a, right=b)), np.diag(
    allclose(B.diag(LowRank(left=b, right=b)), np.diag(

    # Test `Constant`.
    # A 3x5 constant matrix of ones has a diagonal of min(3, 5) = 3 ones.
    allclose(B.diag(Constant(1, rows=3, cols=5)), np.ones(3))

    # Test `Woodbury`.
Source: GitHub — wesselb/stheno, `tests/` (view on GitHub)
# NOTE(review): fragment of a test function — its `def` line and the setup
# of `model` were cut off by the scrape, and the right-hand side of `p2 =`
# is missing, so this is not valid Python as shown.
p = GP(EQ(), graph=model)  # 1D
    p2 =  # 2D

    n = 5
    x = np.linspace(0, 10, n)[:, None]
    # `x1` and `x2` share `x` as their first column but get different random
    # second columns.
    x1 = np.concatenate((x, np.random.randn(n, 1)), axis=1)
    x2 = np.concatenate((x, np.random.randn(n, 1)), axis=1)
    y = p2(x).sample()

    # Condition `p` on `p2` observed at `x1`: the posterior mean at `x` is
    # checked against `y`, and the posterior variance at `x` is checked to be
    # numerically zero (per `abs_err`).
    post = p.condition(p2(x1), y)
    allclose(post(x).mean, y)
    assert abs_err(B.diag(post(x).var)) <= 1e-10

    # Same check using the other random second column.
    post = p.condition(p2(x2), y)
    allclose(post(x).mean, y)
    assert abs_err(B.diag(post(x).var)) <= 1e-10

    # Conversely, condition `p2` on `p` at `x` and check the result at both
    # `x1` and `x2`. Presumably this passes because `p2` depends only on the
    # first input column — TODO confirm against the missing `p2` definition.
    post = p2.condition(p(x), y)
    allclose(post(x1).mean, y)
    allclose(post(x2).mean, y)
    assert abs_err(B.diag(post(x1).var)) <= 1e-10
    assert abs_err(B.diag(post(x2).var)) <= 1e-10
Source: GitHub — wesselb/stheno, `stheno/` (view on GitHub)
def dense(a):
    """Rebuild `a` as an explicit matrix from its diagonal.

    The diagonal of `a` is passed to `B.diag` together with the shape of `a`,
    so the result has the same shape as `a`. Presumably `a` is a
    diagonal-structured matrix, so nothing is lost — TODO confirm at the
    dispatch site.
    """
    diagonal_entries = B.diag(a)
    target_shape = B.shape(a)
    return B.diag(diagonal_entries, *target_shape)
Source: GitHub — wesselb/stheno, `stheno/` (view on GitHub)
def marginals(self):
    """Get the marginals.

    Returns:
        tuple: A tuple containing the predictive means and lower and
            upper 95% central credible interval bounds.
    """
    mean = B.squeeze(self.mean)
    if self.p is None:
        # No process attached: read the variances off the stored covariance.
        variances = B.diag(self.var)
    else:
        # Use the process's kernel evaluated via `elwise` at `self.x`
        # (presumably k(x_i, x_i) for each input — TODO confirm). This branch
        # must be guarded by `else:`; otherwise `self.p.kernel` is accessed
        # when `self.p` is None.
        variances = B.squeeze(B.dense(self.p.kernel.elwise(self.x)))
    # Two standard deviations: the usual approximation of a 95% central
    # interval (1.96 sigma), assuming Gaussian marginals.
    error = 2 * B.sqrt(variances)
    return mean, mean - error, mean + error
Source: GitHub — wesselb/stheno, `stheno/` (view on GitHub)
# Return the diagonal of `a` by combining `B.diag(a.diag)` with another term.
# NOTE(review): truncated by the scrape — the operand inside the first
# `B.diag(` (and its closing paren) is missing, so this line is not valid
# Python as shown. Presumably it read `B.diag(<term>) + B.diag(a.diag)` —
# TODO restore the missing term from upstream.
def diag(a): return B.diag( + B.diag(a.diag)
Source: GitHub — wesselb/stheno, `stheno/` (view on GitHub)
def diag(diag, rows, cols=None):
    """Construct a `rows` by `cols` matrix with `diag` on its main diagonal.

    The diagonal is cut to at most ``min(rows, cols)`` entries. Any columns
    and rows beyond its length are filled with zeros.

    Args:
        diag (tensor): Diagonal entries. Presumably rank 1 — TODO confirm;
            only ``B.shape(diag)[0]`` is inspected.
        rows (int): Number of rows.
        cols (int, optional): Number of columns. Defaults to `rows`.

    Returns:
        tensor: Matrix of shape ``(rows, cols)``.
    """
    cols = rows if cols is None else cols

    # Cut the diagonal to accommodate the size.
    diag = diag[:B.minimum(rows, cols)]
    diag_len, dtype = B.shape(diag)[0], B.dtype(diag)

    # Start with just a `diag_len` by `diag_len` diagonal matrix.
    res = B.diag(diag)

    # PyTorch incorrectly handles dimensions of size 0. Therefore, if the
    # numbers of extra columns and rows are `Number`s, which will be the case
    # if PyTorch is the backend, skip the padding step whenever it would
    # append a tensor with a dimension of size 0.

    # Pad extra columns if necessary. Extend `res` directly rather than
    # building the diagonal matrix a second time.
    extra_cols = cols - diag_len
    if not (isinstance(extra_cols, Number) and extra_cols == 0):
        zeros = B.zeros(dtype, diag_len, extra_cols)
        res = B.concat(res, zeros, axis=1)

    # Pad extra rows if necessary. `res` now has `diag_len + extra_cols`
    # (that is, `cols`) columns.
    extra_rows = rows - diag_len
    if not (isinstance(extra_rows, Number) and extra_rows == 0):
        zeros = B.zeros(dtype, extra_rows, diag_len + extra_cols)
        res = B.concat(res, zeros, axis=0)

    return res
Source: GitHub — wesselb/stheno, `stheno/` (view on GitHub)
# NOTE(review): truncated fragment — it appears to repeat the body of the
# `diag(diag, rows, cols=None)` example on this page. Its `def` line and the
# `cols` default handling were cut off by the scrape, so `diag`, `rows` and
# `cols` are unbound here and the indentation is inconsistent.
diag = diag[:B.minimum(rows, cols)]
    diag_len, dtype = B.shape(diag)[0], B.dtype(diag)

    # PyTorch incorrectly handles dimensions of size 0. Therefore, if the
    # numbers of extra columns and rows are `Number`s, which will be the case if
    # PyTorch is the backend, then perform a check to prevent appending tensors
    # with dimensions of size 0.

    # Start with just a diagonal matrix.
    res = B.diag(diag)

    # Pad extra columns if necessary.
    extra_cols = cols - diag_len
    if not (isinstance(extra_cols, Number) and extra_cols == 0):
        zeros = B.zeros(dtype, diag_len, extra_cols)
        # NOTE(review): `res` already equals `B.diag(diag)` here, so this
        # rebuilds the diagonal matrix a second time; `B.concat(res, ...)`
        # would avoid the duplicate work.
        res = B.concat(B.diag(diag), zeros, axis=1)

    # Pad extra rows if necessary; `diag_len + extra_cols` equals `cols`.
    extra_rows = rows - diag_len
    if not (isinstance(extra_rows, Number) and extra_rows == 0):
        zeros = B.zeros(dtype, extra_rows, diag_len + extra_cols)
        res = B.concat(res, zeros, axis=0)

    return res
Source: GitHub — wesselb/stheno, `stheno/` (view on GitHub)
def matmul(a, b, tr_a=False, tr_b=False):
    """Multiply `a` and `b`, optionally transposing either operand first.

    Presumably both operands are diagonal matrices — TODO confirm. In that
    case the product is diagonal, and its entries are the element-wise product
    of the two diagonals, cut to the shorter diagonal length.

    NOTE(review): truncated by the scrape — the `Diagonal(` call on the last
    line is never closed and its remaining arguments are missing, so this is
    not valid Python as shown.
    """
    a = B.transpose(a) if tr_a else a
    b = B.transpose(b) if tr_b else b
    # Only the first `diag_len` entries of each diagonal overlap.
    diag_len = B.minimum(B.diag_len(a), B.diag_len(b))
    return Diagonal(B.diag(a)[:diag_len] * B.diag(b)[:diag_len],
Source: GitHub — wesselb/stheno, `stheno/` (view on GitHub)
def sum(a, axis=None):
    """Sum the entries of `a`, either in total or along `axis`.

    The fast paths below use only the diagonal of `a`. Presumably `a` is
    diagonal-structured, so every off-diagonal entry is zero — TODO confirm at
    the dispatch site.

    Args:
        a: Matrix to sum.
        axis (int, optional): Axis to sum over; `None` sums all entries.

    Returns:
        The total for `axis=None`. For `axis=0`, a vector of length
        ``B.shape(a)[1]``; for `axis=1`, a vector of length
        ``B.shape(a)[0]``; in both cases the diagonal is followed by zeros.
        Other axes defer to the `Dense` implementation.

    NOTE(review): truncated by the scrape — both `B.concat(` calls are never
    closed and an `else:` appears to be missing before the fallback, so this
    is not valid Python as shown.
    NOTE(review): `axis is 0` / `axis is 1` compare identity against int
    literals (CPython emits a SyntaxWarning for this since 3.8); these should
    be `axis == 0` / `axis == 1`.
    """
    # Efficiently handle a number of common cases.
    if axis is None:
        return B.sum(B.diag(a))
    elif axis is 0:
        # One entry per column: the diagonal, then zeros for the columns past
        # the end of the diagonal.
        return B.concat(B.diag(a),
                        B.zeros(B.dtype(a), B.shape(a)[1] - B.diag_len(a)),
    elif axis is 1:
        # One entry per row: the diagonal, then zeros for the rows past the
        # end of the diagonal.
        return B.concat(B.diag(a),
                        B.zeros(B.dtype(a), B.shape(a)[0] - B.diag_len(a)),
        # Fall back to generic implementation.
        return B.sum.invoke(Dense)(a, axis=axis)