Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_restart(backend, dtype):
    """Check that a two-stage (restarted) run matches on both backends.

    The same pair of sampler runs (the second with ``seed=None``) is
    performed against the default in-memory backend and against the
    backend under test. Every stored component and the final walker
    positions must agree.
    """
    # Reference: two consecutive runs sharing one default backend.
    reference_backend = backends.Backend()
    run_sampler(reference_backend, dtype=dtype)
    expected = run_sampler(reference_backend, seed=None, dtype=dtype)

    with backend() as be:
        # Repeat the identical two-stage run on the backend being tested.
        run_sampler(be, dtype=dtype)
        actual = run_sampler(be, seed=None, dtype=dtype)

        # Every stored component must agree between the two backends.
        for name in ("chain", "log_prob", "blobs"):
            getter = "get_" + name
            _custom_allclose(
                getattr(expected, getter)(),
                getattr(actual, getter)(),
            )

        # The last recorded walker positions must match as well.
        assert np.allclose(
            expected.get_last_sample().coords,
            actual.get_last_sample().coords,
        )
@pytest.mark.parametrize(
    "backend,dtype,blobs",
    product(other_backends, dtypes, [True, False]),
)
def test_backend(backend, dtype, blobs):
    """Check that an alternative backend stores the same results as the default.

    A single sampler run is made against the default in-memory backend and
    against the backend under test. The chain and log-probabilities are
    compared, and the blobs are compared only when they were requested. When
    no blobs were requested, both samplers must report ``None`` for them.
    """
    expected = run_sampler(backends.Backend(), dtype=dtype, blobs=blobs)

    with backend() as be:
        actual = run_sampler(be, dtype=dtype, blobs=blobs)

        if not blobs:
            # Without blobs, neither backend may have stored any.
            assert expected.get_blobs() is None
            assert actual.get_blobs() is None
        components = ["chain", "log_prob"] + (["blobs"] if blobs else [])

        # Check all of the components.
        for name in components:
            getter = "get_" + name
            _custom_allclose(
                getattr(expected, getter)(),
                getattr(actual, getter)(),
            )
# Per-walker accepted-step counts are not read directly (an 'accepteds'
# lookup was previously commented out here). They are reconstructed from
# the acceptance fractions below.
continued_acceptance_fractions = continue_from_ps.get_value(qualifier='acceptance_fractions', **_skip_filter_checks)
# continued_acceptance_fractions [iterations, walkers]
continued_lnprobabilities = continue_from_ps.get_value(qualifier='lnprobabilities', **_skip_filter_checks)
# continued_lnprobabilities [iterations, walkers]

# Restart every walker from its last stored position.
p0 = continued_samples[-1].T
# p0 [parameter, walkers]
nwalkers = int(p0.shape[-1])
start_iteration = continued_lnprobabilities.shape[0]

# fake a backend object from the previous solution so that emcee
# can continue from where it left off and still compute
# autocorrelation times, etc.
backend = emcee.backends.Backend()
backend.nwalkers = int(nwalkers)
backend.ndim = int(len(params_uniqueids))
backend.iteration = start_iteration
# emcee's Backend keeps `accepted` as one count per walker, and its
# acceptance_fraction is accepted / iteration. Multiplying the whole
# [iterations, walkers] history by start_iteration produced a 2-D array,
# so only the most recent fractions are used. atleast_2d keeps this
# working if a 1-D [walkers] array is stored instead.
# NOTE(review): assumes each row holds the cumulative acceptance fraction
# at that iteration -- confirm against where 'acceptance_fractions' is saved.
# Round before the int cast: a bare cast truncates, e.g.
# 0.29 * 100 == 28.999999999999996 would become 28 instead of 29.
backend.accepted = np.rint(np.atleast_2d(continued_acceptance_fractions)[-1] * start_iteration).astype(int)
backend.chain = continued_samples
backend.log_prob = continued_lnprobabilities
backend.initialized = True
# The previous RNG state is not restored here.
backend.random_state = None
esargs['backend'] = backend
params_twigs = [b.get_parameter(uniqueid=uniqueid, **_skip_filter_checks).twig for uniqueid in params_uniqueids]
esargs['pool'] = pool
esargs['nwalkers'] = nwalkers
esargs['ndim'] = len(params_uniqueids)
esargs['log_prob_fn'] = _lnprobability
# Per-walker accepted-step counts are not read directly (an 'accepteds'
# lookup was previously commented out here). They are reconstructed from
# the acceptance fractions below.
continued_acceptance_fractions = continue_from_ps.get_value(qualifier='acceptance_fractions', **_skip_filter_checks)
# continued_acceptance_fractions [iterations, walkers]
continued_lnprobabilities = continue_from_ps.get_value(qualifier='lnprobabilities', **_skip_filter_checks)
# continued_lnprobabilities [iterations, walkers]

# Restart every walker from its last stored position.
p0 = continued_samples[-1].T
# p0 [parameter, walkers]
nwalkers = int(p0.shape[-1])
start_iteration = continued_lnprobabilities.shape[0]

# fake a backend object from the previous solution so that emcee
# can continue from where it left off and still compute
# autocorrelation times, etc.
backend = emcee.backends.Backend()
backend.nwalkers = int(nwalkers)
backend.ndim = int(len(params_uniqueids))
backend.iteration = start_iteration
# emcee's Backend keeps `accepted` as one count per walker, and its
# acceptance_fraction is accepted / iteration. Multiplying the whole
# [iterations, walkers] history by start_iteration produced a 2-D array,
# so only the most recent fractions are used. atleast_2d keeps this
# working if a 1-D [walkers] array is stored instead.
# NOTE(review): assumes each row holds the cumulative acceptance fraction
# at that iteration -- confirm against where 'acceptance_fractions' is saved.
# Round before the int cast: a bare cast truncates, e.g.
# 0.29 * 100 == 28.999999999999996 would become 28 instead of 29.
backend.accepted = np.rint(np.atleast_2d(continued_acceptance_fractions)[-1] * start_iteration).astype(int)
backend.chain = continued_samples
backend.log_prob = continued_lnprobabilities
backend.initialized = True
# The previous RNG state is not restored here.
backend.random_state = None
# reconstructing blobs will be messy, so we'll just get the correct
# shape with nones, but add to the existing dictionary later.
# The placeholder has shape continued_samples.shape[:-1] + (2,); the
# trailing 2 presumably matches the number of blobs returned per sample
# by the log-probability function -- TODO confirm.
backend.blobs = np.full(tuple(list(continued_samples.shape[:-1])+[2]), fill_value=None)
esargs['backend'] = backend
params_twigs = [b.get_parameter(uniqueid=uniqueid, **_skip_filter_checks).twig for uniqueid in params_uniqueids]