How to use the arviz.stats.stats_utils.not_valid function in arviz

To help you get started, we’ve selected a few ArviZ examples based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub · arviz-devs / arviz / arviz / stats / diagnostics.py · View on GitHub
def _ess_mad(ary, relative=False):
    """Compute split-ESS of the median-absolute-deviation indicator.

    The draws are folded around their overall median (absolute distance to
    the median). They are then turned into a boolean indicator of whether
    each folded value is at or below the median of the folded values. Finally
    they are split into half-chains and z-scaled before the effective sample
    size is computed.

    Returns ``np.nan`` when ``_not_valid`` rejects the input (it is called
    with ``min_draws=4, min_chains=1``).
    """
    values = np.asarray(ary)
    shape_requirements = {"min_draws": 4, "min_chains": 1}
    if _not_valid(values, shape_kwargs=shape_requirements):
        return np.nan
    folded = abs(values - np.median(values))
    at_or_below_median = folded <= np.median(folded)
    scaled_halves = _z_scale(_split_chains(at_or_below_median))
    return _ess(scaled_halves, relative=relative)
GitHub · arviz-devs / arviz / arviz / stats / diagnostics.py · View on GitHub
def _rhat(ary):
    """Compute R-hat for a 2d array of chains.

    No chain splitting or rank normalization is done here; callers pass
    already-transformed draws.

    Parameters
    ----------
    ary : array_like
        Converted to a float array. Its shape is unpacked as
        ``(n_chains, n_draws)``, so it must be 2d. Rows are treated as
        chains and columns as draws.

    Returns
    -------
    float
        ``sqrt((B / W + n - 1) / n)``, where ``n`` is the number of draws
        per chain, ``B`` the between-chain variance and ``W`` the mean
        within-chain variance. Returns ``np.nan`` when ``_not_valid``
        rejects the input.
    """
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    _, num_samples = ary.shape

    # Per-chain means and per-chain variances (ddof=1).
    chain_mean = np.mean(ary, axis=1)
    chain_var = _numba_var(svar, np.var, ary, axis=1, ddof=1)
    # Between-chain variance B: variance of the chain means, scaled by n.
    between_chain_variance = num_samples * _numba_var(
        svar, np.var, chain_mean, axis=None, ddof=1
    )
    # Within-chain variance W: average of the per-chain variances.
    within_chain_variance = np.mean(chain_var)
    # sqrt(var_hat / W) with var_hat = (n - 1) / n * W + B / n, simplified.
    rhat_value = np.sqrt(
        (between_chain_variance / within_chain_variance + num_samples - 1) / num_samples
    )
    return rhat_value
GitHub · arviz-devs / arviz / arviz / stats / diagnostics.py · View on GitHub
# ess sd
    ess_sd_value = _ess_sd(ary)

    # ess bulk
    z_split = _z_scale(_split_chains(ary))
    ess_bulk_value = _ess(z_split)

    # ess tail
    quantile05, quantile95 = _quantile(ary, [0.05, 0.95])
    iquantile05 = ary <= quantile05
    quantile05_ess = _ess(_split_chains(iquantile05))
    iquantile95 = ary <= quantile95
    quantile95_ess = _ess(_split_chains(iquantile95))
    ess_tail_value = min(quantile05_ess, quantile95_ess)

    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        rhat_value = np.nan
    else:
        # r_hat
        rhat_bulk = _rhat(z_split)
        ary_folded = np.abs(ary - np.median(ary))
        rhat_tail = _rhat(_z_scale(_split_chains(ary_folded)))
        rhat_value = max(rhat_bulk, rhat_tail)

    # mcse_mean
    sd = np.std(ary, ddof=1)
    mcse_mean_value = sd / np.sqrt(ess_mean_value)

    # mcse_sd
    fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess_sd_value) ** (ess_sd_value - 1) - 1)
    mcse_sd_value = sd * fac_mcse_sd
GitHub · arviz-devs / arviz / arviz / stats / diagnostics.py · View on GitHub
def _ess_sd(ary, relative=False):
    """Return the effective sample size relevant to the standard deviation.

    After splitting the chains, the ESS is computed for the draws and for
    their squares; the smaller of the two is returned. Returns ``np.nan``
    when ``_not_valid`` rejects the input (it is called with
    ``min_draws=4, min_chains=1``).
    """
    values = np.asarray(ary)
    if _not_valid(values, shape_kwargs={"min_draws": 4, "min_chains": 1}):
        return np.nan
    halves = _split_chains(values)
    ess_first_moment = _ess(halves, relative=relative)
    ess_second_moment = _ess(halves ** 2, relative=relative)
    return min(ess_first_moment, ess_second_moment)
GitHub · arviz-devs / arviz / arviz / stats / diagnostics.py · View on GitHub
def _ess_z_scale(ary, relative=False):
    """Calculate ESS for z-scaled, split chains.

    The chains are split into halves and z-scaled, then the effective sample
    size is computed. Returns ``np.nan`` when ``_not_valid`` rejects the
    input (it is called with ``min_draws=4, min_chains=1``).
    """
    values = np.asarray(ary)
    if _not_valid(values, shape_kwargs={"min_draws": 4, "min_chains": 1}):
        return np.nan
    halves = _split_chains(values)
    scaled = _z_scale(halves)
    return _ess(scaled, relative=relative)
GitHub · arviz-devs / arviz / arviz / stats / diagnostics.py · View on GitHub
def _ess_tail(ary, prob=None, relative=False):
    """Compute the effective sample size for the tail.

    The result is ``min(ess_quantile(prob_low), ess_quantile(prob_high))``.

    Parameters
    ----------
    ary : array_like
        Draws to assess.
    prob : float or pair of floats, optional
        Defaults to ``(0.05, 0.95)``. A scalar ``p`` is expanded to
        ``(p, 1 - p)``. Any one-dimensional length-2 array-like (tuple,
        list, numpy array) is used directly as ``(prob_low, prob_high)``.
    relative : bool
        Forwarded to ``_ess_quantile``.

    Returns
    -------
    float
        The tail ESS, or ``np.nan`` when ``_not_valid`` rejects the input.

    Raises
    ------
    ValueError
        If ``prob`` is neither a scalar nor a pair of values.
    """
    if prob is None:
        prob = (0.05, 0.95)
    elif np.ndim(prob) == 0:
        # Scalars (including numpy 0-d values) define a symmetric pair of
        # tails. np.ndim is used instead of isinstance(..., Sequence) so a
        # 2-element numpy array is not mistaken for a scalar.
        prob = (prob, 1 - prob)

    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan

    if np.ndim(prob) != 1 or len(prob) != 2:
        raise ValueError(
            "Prob argument in ess tail must be a scalar or a (lower, upper) pair"
        )
    prob_low, prob_high = prob
    quantile_low_ess = _ess_quantile(ary, prob_low, relative=relative)
    quantile_high_ess = _ess_quantile(ary, prob_high, relative=relative)
    return min(quantile_low_ess, quantile_high_ess)
GitHub · arviz-devs / arviz / arviz / stats / diagnostics.py · View on GitHub
def _ess_local(ary, prob, relative=False):
    """Compute the effective sample size of the interval indicator.

    ``prob`` must be a ``(lower, upper)`` pair of probabilities. The draws
    are turned into an indicator of lying between the corresponding
    quantiles (inclusive on both ends). The chains are then split and the
    ESS of that indicator is returned.

    Returns ``np.nan`` when ``_not_valid`` rejects the input (it is called
    with ``min_draws=4, min_chains=1``); this check happens before ``prob``
    is validated.

    Raises
    ------
    TypeError
        If ``prob`` is None.
    ValueError
        If ``prob`` does not have exactly two entries.
    """
    values = np.asarray(ary)
    if _not_valid(values, shape_kwargs={"min_draws": 4, "min_chains": 1}):
        return np.nan
    if prob is None:
        raise TypeError("Prob not defined.")
    if len(prob) != 2:
        raise ValueError("Prob argument in ess local must be upper and lower bound")
    bounds = _quantile(values, prob)
    lower_bound, upper_bound = bounds[0], bounds[1]
    inside_interval = (lower_bound <= values) & (values <= upper_bound)
    return _ess(_split_chains(inside_interval), relative=relative)
GitHub · arviz-devs / arviz / arviz / stats / diagnostics.py · View on GitHub
def _ess_folded(ary, relative=False):
    """Calculate split-ESS on folded draws.

    The chains are split into halves and passed through ``_z_fold`` before
    the effective sample size is computed. Returns ``np.nan`` when
    ``_not_valid`` rejects the input (it is called with
    ``min_draws=4, min_chains=1``).
    """
    values = np.asarray(ary)
    if _not_valid(values, shape_kwargs={"min_draws": 4, "min_chains": 1}):
        return np.nan
    halves = _split_chains(values)
    folded = _z_fold(halves)
    return _ess(folded, relative=relative)
GitHub · arviz-devs / arviz / arviz / stats / diagnostics.py · View on GitHub
def _mcse_mean(ary):
    """Compute the Monte Carlo standard error of the mean.

    The result is ``sd / sqrt(ess)``, where ``ess`` comes from ``_ess_mean``
    and ``sd`` is the sample standard deviation (ddof=1) of all draws.
    When ``Numba.numba_flag`` is set, ``sd`` is computed from ``svar`` on
    the flattened draws via ``_sqrt``; otherwise ``np.std`` is used.

    Returns ``np.nan`` when ``_not_valid`` rejects the input (it is called
    with ``min_draws=4, min_chains=1``).
    """
    use_numba = Numba.numba_flag
    values = np.asarray(ary)
    if _not_valid(values, shape_kwargs={"min_draws": 4, "min_chains": 1}):
        return np.nan
    ess = _ess_mean(values)
    sd = (
        _sqrt(svar(np.ravel(values), ddof=1), np.zeros(1))
        if use_numba
        else np.std(values, ddof=1)
    )
    return sd / np.sqrt(ess)
GitHub · arviz-devs / arviz / arviz / stats / diagnostics.py · View on GitHub
def _rhat_z_scale(ary):
    """Compute R-hat on split, z-scaled chains.

    The chains are split into halves and z-scaled, then passed to ``_rhat``.
    Returns ``np.nan`` when ``_not_valid`` rejects the input (it is called
    with ``min_draws=4, min_chains=2``).
    """
    values = np.asarray(ary)
    if _not_valid(values, shape_kwargs={"min_draws": 4, "min_chains": 2}):
        return np.nan
    scaled_halves = _z_scale(_split_chains(values))
    return _rhat(scaled_halves)