How to use the numpyro.__version__.startswith version check in numpyro

To help you get started, we’ve selected a few numpyro examples based on popular ways numpyro.__version__.startswith is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pyro-ppl / numpyro / examples / baseball.py View on Github external
test, _ = fetch_test()
    at_bats, hits = train[:, 0], train[:, 1]
    season_at_bats, season_hits = test[:, 0], test[:, 1]
    for i, model in enumerate((fully_pooled,
                               not_pooled,
                               partially_pooled,
                               partially_pooled_with_logit,
                               )):
        rng_key, rng_key_predict = random.split(random.PRNGKey(i + 1))
        zs = run_inference(model, at_bats, hits, rng_key, args)
        predict(model, at_bats, hits, zs, rng_key_predict, player_names)
        predict(model, season_at_bats, season_hits, zs, rng_key_predict, player_names, train=False)


if __name__ == "__main__":
    # Fail fast unless the installed numpyro version string starts with 0.2.4.
    assert numpyro.__version__.startswith("0.2.4")

    # Each entry is (flag names, add_argument keyword options). Entries are
    # registered in this order, so the generated --help output is unchanged.
    cli_options = (
        (("-n", "--num-samples"), dict(nargs="?", default=3000, type=int)),
        (("--num-warmup",), dict(nargs="?", default=1500, type=int)),
        (("--num-chains",), dict(nargs="?", default=1, type=int)),
        (("--algo",), dict(default="NUTS", type=str,
                           help='whether to run "HMC", "NUTS", or "SA"')),
        (("-dp", "--disable-progbar"), dict(action="store_true", default=False,
                                            help="whether to disable progress bar")),
        (("--device",), dict(default="cpu", type=str, help='use "cpu" or "gpu".')),
    )
    parser = argparse.ArgumentParser(description="Baseball batting average using HMC")
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    # Apply the platform and host device count from the CLI before running.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
github pyro-ppl / numpyro / benchmarks / hmm.py View on Github external
elif args.backend == 'stan':
        result = stan_inference(data, args)

    out_filename = 'hmm_{}{}_{}_seed={}.txt'.format(args.backend,
                                                    "(x64)" if args.x64 else "",
                                                    args.device,
                                                    args.seed)
    with open(os.path.join(DATA_DIR, out_filename), 'w') as f:
        f.write('\t'.join(['num_leapfrog', 'n_eff', 'total_time', 'time_per_leapfrog', 'time_per_eff_sample']))
        f.write('\n')
        f.write('\t'.join([str(x) for x in result]))
        f.write('\n')


if __name__ == "__main__":
    # Fail fast unless the installed numpyro version string starts with 0.2.3.
    assert numpyro.__version__.startswith("0.2.3")
    parser = argparse.ArgumentParser(description="HMM example")

    # These options all share nargs="?" and type=int, so they are registered
    # from a (flags, default) table, in the original order.
    int_options = (
        (("-n", "--num-samples"), 1000),
        (("--num-warmup",), 1000),
        (("--num-chains",), 1),
        (("--num-categories",), 3),
        (("--num-words",), 10),
        (("--num-supervised",), 100),
        (("--num-unsupervised",), 500),
        (("--seed",), 2019),
    )
    for flags, default in int_options:
        parser.add_argument(*flags, nargs="?", default=default, type=int)

    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    parser.add_argument("--backend", default="numpyro", type=str,
                        help='either "numpyro", "pyro", or "stan"')
    # Plain boolean switches.
    for switch in ("--x64", "--disable-progbar"):
        parser.add_argument(switch, action="store_true")
    args = parser.parse_args()

    numpyro.enable_x64(args.x64)
github pyro-ppl / numpyro / examples / funnel.py View on Github external
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(8, 8))

    ax1.plot(samples['x'][:, 0], samples['y'], "go", alpha=0.3)
    ax1.set(xlim=(-20, 20), ylim=(-9, 9), ylabel='y',
            title='Funnel samples with centered parameterization')

    ax2.plot(reparam_samples['x'][:, 0], reparam_samples['y'], "go", alpha=0.3)
    ax2.set(xlim=(-20, 20), ylim=(-9, 9), xlabel='x[0]', ylabel='y',
            title='Funnel samples with non-centered parameterization')

    plt.savefig('funnel_plot.pdf')
    plt.tight_layout()


if __name__ == "__main__":
    # Fail fast unless the installed numpyro version string starts with 0.2.4.
    assert numpyro.__version__.startswith("0.2.4")

    parser = argparse.ArgumentParser(
        description="Non-centered reparameterization example",
    )
    # Integer settings, each with an optional value (nargs="?").
    parser.add_argument("-n", "--num-samples", type=int, default=1000, nargs="?")
    parser.add_argument("--num-warmup", type=int, default=1000, nargs="?")
    parser.add_argument("--num-chains", type=int, default=1, nargs="?")
    # Platform name handed to numpyro.set_platform below.
    parser.add_argument("--device", type=str, default="cpu", help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # Apply the platform and host device count from the CLI before running.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
github pyro-ppl / numpyro / examples / cb.py View on Github external
print("total reward", total_reward)
    print("max reward", max_reward)

    print(results)

    log_file = 'cb.{}.num_arms_{}.num_data_{}.P_{}.S_{}.bias_{}.seed_{}'
    log_file = log_file.format(args['model'], args['num_arms'], args['num_data'], P,
                               args['active_dimensions'], args['bias'], args['seed'])

    with open(args['log_dir'] + log_file + '.pkl', 'wb') as f:
        pickle.dump(results, f, protocol=2)


if __name__ == "__main__":
    # Fail fast unless the installed numpyro version string starts with 0.2.4.
    assert numpyro.__version__.startswith("0.2.4")
    parser = argparse.ArgumentParser(description="contextual bandits")

    # Options that take an optional value (nargs="?"), listed as
    # (flags, default, type) and registered in the original order.
    optional_value_options = (
        (("-n", "--num-samples"), 60, int),
        (("-m", "--model"), "gp", str),
        (("--num-warmup",), 60, int),
        (("--mtd",), 5, int),
        (("--num-data",), 256, int),
        (("--num-arms",), 8, int),
        (("--num-dimensions",), 8, int),
        (("--seed",), 0, int),
        (("--bias",), -2.0, float),
        (("--active-dimensions",), 4, int),
    )
    for flags, default, value_type in optional_value_options:
        parser.add_argument(*flags, nargs="?", default=default, type=value_type)

    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    parser.add_argument("--log-dir", default="./cbres4/", type=str)
    args = parser.parse_args()

    numpyro.set_platform(args.device)
github pyro-ppl / numpyro / benchmarks / covtype.py View on Github external
result = stan_inference(data, args)
    elif args.backend == 'edward':
        result = edward_inference(data, args)

    out_filename = 'covtype_{}_{}_seed={}.txt'.format(args.backend,
                                                      args.device,
                                                      args.seed)
    with open(os.path.join(DATA_DIR, out_filename), 'w') as f:
        f.write('\t'.join(['num_leapfrog', 'n_eff', 'total_time', 'time_per_leapfrog', 'time_per_eff_sample']))
        f.write('\n')
        f.write('\t'.join([str(x) for x in result]))
        f.write('\n')


if __name__ == "__main__":
    # Fail fast unless the installed numpyro version string starts with 0.2.1.
    assert numpyro.__version__.startswith('0.2.1')
    # The description is shown by --help, so it must name this benchmark
    # (covtype) rather than the HMM benchmark.
    parser = argparse.ArgumentParser(description="Covtype benchmark")
    parser.add_argument("-n", "--num-samples", nargs="?", default=30, type=int)
    parser.add_argument("--num-warmup", nargs='?', default=0, type=int)
    parser.add_argument("--num-chains", nargs='?', default=1, type=int)
    parser.add_argument("--seed", nargs='?', default=2019, type=int)
    parser.add_argument("--device", default='cpu', type=str, help='use "cpu" or "gpu".')
    parser.add_argument("--backend", default='numpyro', type=str, help='either "numpyro", "pyro", "stan", or "edward"')
    parser.add_argument("--x64", action="store_true")
    parser.add_argument("--disable-progbar", action="store_true")
    args = parser.parse_args()

    # Configure numpyro from the CLI: 64-bit mode, platform, host device count.
    numpyro.enable_x64(args.x64)
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)
    # Set torch's default tensor type to match: CUDA tensor classes when
    # --device is "gpu", double precision when --x64 is given.
    tt = torch.cuda if args.device == "gpu" else torch
    torch.set_default_tensor_type(tt.DoubleTensor if args.x64 else tt.FloatTensor)
github pyro-ppl / numpyro / examples / hmm.py View on Github external
x = onp.linspace(0, 1, 101)
    for i in range(transition_prob.shape[0]):
        for j in range(transition_prob.shape[1]):
            ax.plot(x, gaussian_kde(samples['transition_prob'][:, i, j])(x),
                    label="transition_prob[{}, {}], true value = {:.2f}"
                    .format(i, j, transition_prob[i, j]))
    ax.set(xlabel="Probability", ylabel="Frequency",
           title="Transition probability posterior")

    plt.savefig("hmm_plot.pdf")
    plt.tight_layout()


if __name__ == "__main__":
    # Fail fast unless the installed numpyro version string starts with 0.2.3.
    assert numpyro.__version__.startswith("0.2.3")

    parser = argparse.ArgumentParser(
        description="Semi-supervised Hidden Markov Model",
    )
    # Integer options that require a value when given.
    for flag, default in (
        ("--num-categories", 3),
        ("--num-words", 10),
        ("--num-supervised", 100),
        ("--num-unsupervised", 500),
    ):
        parser.add_argument(flag, default=default, type=int)
    # Integer options whose value is optional (nargs="?").
    parser.add_argument("-n", "--num-samples", type=int, default=1000, nargs="?")
    parser.add_argument("--num-warmup", type=int, default=500, nargs="?")
    parser.add_argument("--num-chains", type=int, default=1, nargs="?")
    parser.add_argument("--device", type=str, default="cpu", help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # Apply the platform and host device count from the CLI before running.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
github pyro-ppl / numpyro / examples / minipyro.py View on Github external
svi_state = fori_loop(0, args.num_steps, body_fn, svi_state)

    # Report the final values of the variational parameters
    # in the guide after training.
    params = svi.get_params(svi_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    assert np.abs(params["guide_loc"] - 3.0) < 0.1


if __name__ == "__main__":
    # Fail fast unless the installed numpyro version string starts with 0.2.3.
    assert numpyro.__version__.startswith("0.2.3")

    parser = argparse.ArgumentParser(description="Mini Pyro demo")
    parser.add_argument(
        "-f", "--full-pyro",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "-n", "--num-steps",
        type=int,
        default=1001,
    )
    parser.add_argument(
        "-lr", "--learning-rate",
        type=float,
        default=0.02,
    )

    # Parse the command line and pass the namespace straight to main.
    args = parser.parse_args()
    main(args)
github pyro-ppl / numpyro / examples / gp.py View on Github external
fig, ax = plt.subplots(1, 1)

    # plot training data
    ax.plot(X, Y, 'kx')
    # plot 90% confidence level of predictions
    ax.fill_between(X_test, percentiles[0, :], percentiles[1, :], color='lightblue')
    # plot mean prediction
    ax.plot(X_test, mean_prediction, 'blue', ls='solid', lw=2.0)
    ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

    plt.savefig("gp_plot.pdf")
    plt.tight_layout()


if __name__ == "__main__":
    # Fail fast unless the installed numpyro version string starts with 0.2.3.
    assert numpyro.__version__.startswith("0.2.3")
    parser = argparse.ArgumentParser(description="Gaussian Process example")

    # These options all share nargs="?" and type=int, so they are registered
    # from a (flags, default) table, in the original order.
    int_options = (
        (("-n", "--num-samples"), 1000),
        (("--num-warmup",), 1000),
        (("--num-chains",), 1),
        (("--num-data",), 25),
    )
    for flags, default in int_options:
        parser.add_argument(*flags, nargs="?", default=default, type=int)
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # Apply the platform and host device count from the CLI before running.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
github pyro-ppl / numpyro / examples / covtype.py View on Github external
rng_key = random.PRNGKey(1)
    start = time.time()
    kernel = NUTS(model, trajectory_length=trajectory_length)
    mcmc = MCMC(kernel, 0, args.num_samples)
    mcmc.run(rng_key, features, labels)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)


def main(args):
    """Load the dataset and run the HMC benchmark on it.

    :param args: parsed command-line arguments, passed through unchanged
        to ``benchmark_hmc``.
    """
    dataset = _load_dataset()
    # The loader's result must unpack into exactly (features, labels).
    features, labels = dataset
    benchmark_hmc(args, features, labels)


if __name__ == "__main__":
    # Fail fast unless the installed numpyro version string starts with 0.2.3.
    assert numpyro.__version__.startswith("0.2.3")

    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument(
        "-n", "--num-samples",
        type=int, default=100, help="number of samples",
    )
    parser.add_argument(
        "--num-steps",
        type=int, default=10, help='number of steps (for "HMC")',
    )
    parser.add_argument("--num-chains", type=int, default=1, nargs="?")
    parser.add_argument(
        "--algo",
        type=str, default="NUTS", help='whether to run "HMC" or "NUTS"',
    )
    parser.add_argument(
        "--device",
        type=str, default="cpu", help='use "cpu" or "gpu".',
    )
    args = parser.parse_args()

    # Apply the platform and host device count from the CLI before running.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)