How to use the pymc3.backends.base.MultiTrace class in pymc3

To help you get started, we’ve selected a few pymc3 examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pymc-devs / pymc3 / pymc3 / backends / text.py View on Github external
A MultiTrace instance
    """
    files = glob(os.path.join(name, 'chain-*.csv'))

    if len(files) == 0:
        raise ValueError('No files present in directory {}'.format(name))

    straces = []
    for f in files:
        chain = int(os.path.splitext(f)[0].rsplit('-', 1)[1])
        model_vars_in_chain = _parse_chain_vars(f, model)
        strace = Text(name, model=model, vars=model_vars_in_chain)
        strace.chain = chain
        strace.filename = f
        straces.append(strace)
    return base.MultiTrace(straces)
github pymc-devs / pymc3 / pymc3 / sampling.py View on Github external
step,
        start,
        parallelize,
        tune=tune,
        model=model,
        random_seed=random_seed,
        progressbar=progressbar,
    )

    if progressbar:
        sampling = progress_bar(sampling, total=draws, display=progressbar)

    latest_traces = None
    for it, traces in enumerate(sampling):
        latest_traces = traces
    return MultiTrace(latest_traces)
github pymc-devs / pymc3 / pymc3 / diagnostics.py View on Github external
'Keyword argument varnames renamed to var_names, and will be removed in pymc3 3.8',
            DeprecationWarning
            )
    def rscore(x, num_samples):
        """Return the Gelman-Rubin potential scale reduction for ``x``.

        ``x`` is reduced along axis 1 within each chain and along axis 0
        across chains, i.e. it is laid out as (chain, sample, ...).
        ``num_samples`` is the number of samples per chain used to scale
        the variance terms.
        """
        # Per-chain summaries along the sample axis.
        chain_means = np.mean(x, axis=1)
        chain_vars = np.var(x, axis=1, ddof=1)

        # Between-chain variance: spread of the chain means.
        between = num_samples * np.var(chain_means, axis=0, ddof=1)

        # Within-chain variance: average of the per-chain variances.
        within = np.mean(chain_vars, axis=0)

        # Pooled estimate of the marginal posterior variance.
        pooled = within * (num_samples - 1) / num_samples + between / num_samples

        ratio = pooled / within
        return np.sqrt(ratio)

    if not isinstance(mtrace, MultiTrace):
        # Return rscore for passed arrays
        return rscore(np.array(mtrace), mtrace.shape[1])

    if mtrace.nchains < 2:
        raise ValueError(
            'Gelman-Rubin diagnostic requires multiple chains '
            'of the same length.')

    if var_names is None:
        var_names = get_default_varnames(mtrace.varnames, include_transformed=include_transformed)

    Rhat = {}

    for var in var_names:
        x = np.array(mtrace.get_values(var, combine=False))
        num_samples = x.shape[1]
github pymc-devs / pymc3 / pymc3 / backends / hdf5.py View on Github external
----------
    name : str
        Path to HDF5 arrays file
    model : Model
        If None, the model is taken from the `with` context.

    Returns
    -------
    A MultiTrace instance
    """
    straces = []
    for chain in HDF5(name, model=model).chains:
        trace = HDF5(name, model=model)
        trace.chain = chain
        straces.append(trace)
    return base.MultiTrace(straces)
github pymc-devs / pymc3 / pymc3 / sampling.py View on Github external
def _choose_backend(trace, chain, shortcuts=None, **kwds):
    """Resolve the ``trace`` argument into a backend object for one chain.

    Parameters
    ----------
    trace : BaseTrace, MultiTrace, None, shortcut key, or other
        * A ``BaseTrace`` is returned unchanged.
        * For a ``MultiTrace``, the entry of ``trace._straces`` for
          ``chain`` is returned.
        * ``None`` yields a fresh ``NDArray(**kwds)``.
        * A key present in ``shortcuts`` selects the backend class and
          name registered under that key.
        * Anything for which the shortcut lookup raises ``TypeError``
          (for example an unhashable object such as a list) is passed to
          ``NDArray`` as its ``vars`` argument.
    chain : int
        Chain number used to index into a ``MultiTrace``.
    shortcuts : dict, optional
        Mapping from key to a dict with ``"backend"`` and ``"name"``
        entries. Defaults to ``pm.backends._shortcuts``.
    **kwds
        Forwarded to the backend constructor.

    Returns
    -------
    The selected or newly constructed backend object.

    Raises
    ------
    ValueError
        If the shortcut lookup fails with ``KeyError``.
    """
    if isinstance(trace, BaseTrace):
        return trace
    if isinstance(trace, MultiTrace):
        return trace._straces[chain]
    if trace is None:
        return NDArray(**kwds)

    if shortcuts is None:
        shortcuts = pm.backends._shortcuts

    # Guard only the lookups. Constructing the backend happens outside the
    # ``try`` so that a TypeError/KeyError raised by the constructor itself
    # is reported as-is instead of being mistaken for a bad ``trace`` value.
    try:
        entry = shortcuts[trace]
        backend = entry["backend"]
        name = entry["name"]
    except TypeError:
        return NDArray(vars=trace, **kwds)
    except KeyError as err:
        raise ValueError("Argument `trace` is invalid.") from err

    return backend(name, **kwds)
github pymc-devs / pymc3 / pymc3 / sampling.py View on Github external
model=None,
    **kwargs
):
    skip_first = kwargs.get("skip_first", 0)

    sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed)
    _pbar_data = None
    _pbar_data = {"chain": chain, "divergences": 0}
    _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
    sampling = progress_bar(sampling, total=draws, display=progressbar)
    sampling.comment = _desc.format(**_pbar_data)
    try:
        strace = None
        for it, (strace, diverging) in enumerate(sampling):
            if it >= skip_first:
                trace = MultiTrace([strace])
                if diverging and _pbar_data is not None:
                    _pbar_data["divergences"] += 1
                    sampling.comment = _desc.format(**_pbar_data)
    except KeyboardInterrupt:
        pass
    return strace
github pymc-devs / pymc3 / pymc3 / backends / ndarray.py View on Github external
directory : str
        Path to a pymc3 serialized trace
    model : pm.Model (optional)
        Model used to create the trace.  Can also be inferred from context

    Returns
    -------
    pm.Multitrace that was saved in the directory
    """
    straces = []
    for subdir in glob.glob(os.path.join(directory, '*')):
        if os.path.isdir(subdir):
            straces.append(SerializeNDArray(subdir).load(model))
    if not straces:
        raise TraceDirectoryError("%s is not a PyMC3 saved chain directory." % directory)
    return base.MultiTrace(straces)
github pymc-devs / pymc3 / pymc3 / step_methods / ATMCMC.py View on Github external
while step.beta < 1.:
                print('Beta: ' + str(step.beta), ' Stage: ' + str(step.stage))
                if step.stage == 0:
                    # Initial stage
                    print('Sample initial stage: ...')
                    stage_path = homepath + '/stage_' + str(step.stage)
                    trace = Text(stage_path, model=model)
                    initial = _iter_initial(step, chain=chain, trace=trace)
                    progress = progress_bar(step.n_chains)
                    try:
                        for i, strace in enumerate(initial):
                            if progressbar:
                                progress.update(i)
                    except KeyboardInterrupt:
                        strace.close()
                    mtrace = MultiTrace([strace])
                    step.population, step.array_population, step.likelihoods = \
                        step.select_end_points(mtrace)
                    step.beta, step.old_beta, step.weights = step.calc_beta()
                    step.covariance = step.calc_covariance()
                    step.res_indx = step.resample()
                    step.stage += 1
                    del(strace, mtrace, trace)
                else:
                    if progressbar and njobs > 1:
                        progressbar = False
                    # Metropolis sampling intermediate stages
                    stage_path = homepath + '/stage_' + str(step.stage)
                    step.proposal_dist = MvNPd(step.covariance)

                    sample_args = {
                        'draws': n_steps,