How to use the `numpyro.set_platform` function in NumPyro

To help you get started, we’ve selected a few examples of `numpyro.set_platform`, based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: pyro-ppl/numpyro, examples/hmm_enum.py (view on GitHub)
# Excerpt from the script entry point: the earlier lines (including the
# creation of `parser`) are not shown in this snippet.
# NOTE(review): with nargs='?' and no `const`, passing -n/--num-samples
# without a value yields None instead of the default 1000.
parser.add_argument('-n', '--num-samples', nargs='?', default=1000, type=int)

    parser.add_argument("-d", "--hidden-dim", default=16, type=int)
    # No default: args.truncate is None unless --truncate is given.
    parser.add_argument("--truncate", type=int)
    parser.add_argument("--num-sequences", default=17, type=int)
    parser.add_argument("--print-shapes", action="store_true")
    parser.add_argument("--kernel", default='nuts', type=str)
    parser.add_argument('--num-warmup', nargs='?', default=500, type=int)
    parser.add_argument("--num-chains", nargs='?', default=1, type=int)
    parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
    parser.add_argument('--time-compilation', action='store_true')
    parser.add_argument('--debug', action='store_true')

    args = parser.parse_args()

    # Configure numpyro before main() runs: the --device value ("cpu" or
    # "gpu" per its help text) goes to set_platform, and the host device
    # count is set to --num-chains.
    # NOTE(review): presumably so that several chains can run in parallel on
    # host devices; confirm against the numpyro docs.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    # With --debug, main() runs inside jax.disable_jit() so code executes
    # without JIT compilation; otherwise it runs normally.
    if args.debug:
        with jax.disable_jit():
            main(args)
    else:
        main(args)
Source: pyro-ppl/numpyro, examples/hmm.py (view on GitHub)
if __name__ == '__main__':
    # Refuse to run against a numpyro release this example was not written
    # for. An explicit check is used instead of `assert` so that the guard is
    # not silently skipped when Python runs with -O (which strips asserts).
    if not numpyro.__version__.startswith('0.2.3'):
        raise RuntimeError(
            'This example expects numpyro 0.2.3, found {}'.format(numpyro.__version__))

    # Command-line options; every option has a default, so the script can be
    # run without arguments.
    parser = argparse.ArgumentParser(description='Semi-supervised Hidden Markov Model')
    parser.add_argument('--num-categories', default=3, type=int)
    parser.add_argument('--num-words', default=10, type=int)
    parser.add_argument('--num-supervised', default=100, type=int)
    parser.add_argument('--num-unsupervised', default=500, type=int)
    parser.add_argument('-n', '--num-samples', nargs='?', default=1000, type=int)
    parser.add_argument('--num-warmup', nargs='?', default=500, type=int)
    parser.add_argument("--num-chains", nargs='?', default=1, type=int)
    parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # Configure numpyro before main() runs: --device is forwarded to
    # set_platform, and the host device count is set to --num-chains.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
Source: pyro-ppl/numpyro, benchmarks/sparse_regression.py (view on GitHub)
# Excerpt from the benchmark's entry point (the enclosing lines are not shown
# in this snippet): builds the CLI, configures numpyro, then calls main().
parser = argparse.ArgumentParser(description="Sparse regression example")
    parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
    parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
    parser.add_argument("--num-chains", nargs='?', default=1, type=int)
    parser.add_argument("--num-data", nargs='?', default=100, type=int)
    parser.add_argument("--num-dimensions", nargs='?', default=50, type=int)
    parser.add_argument("--active-dimensions", nargs='?', default=3, type=int)
    parser.add_argument("--seed", nargs='?', default=2019, type=int)
    parser.add_argument("--device", default='cpu', type=str, help='use "cpu" or "gpu".')
    # Which inference backend to benchmark ("numpyro" or "stan" per the help text).
    parser.add_argument("--backend", default='numpyro', type=str, help='either "numpyro" or "stan"')
    parser.add_argument("--x64", action="store_true")
    parser.add_argument("--disable-progbar", action="store_true")
    args = parser.parse_args()

    # Global numpyro configuration, applied before main() runs:
    # 64-bit mode follows the --x64 flag, the platform comes from --device,
    # and the host device count is set to --num-chains.
    numpyro.enable_x64(args.x64)
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
Source: pyro-ppl/numpyro, examples/baseball.py (view on GitHub)
if __name__ == "__main__":
    # Refuse to run against a numpyro release this example was not written
    # for. An explicit check is used instead of `assert` so that the guard is
    # not silently skipped when Python runs with -O (which strips asserts).
    if not numpyro.__version__.startswith('0.2.4'):
        raise RuntimeError(
            'This example expects numpyro 0.2.4, found {}'.format(numpyro.__version__))

    # Command-line options; every option has a default, so the script can be
    # run without arguments.
    parser = argparse.ArgumentParser(description="Baseball batting average using HMC")
    parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
    parser.add_argument("--num-warmup", nargs='?', default=1500, type=int)
    parser.add_argument("--num-chains", nargs='?', default=1, type=int)
    parser.add_argument('--algo', default='NUTS', type=str,
                        help='whether to run "HMC", "NUTS", or "SA"')
    parser.add_argument('-dp', '--disable-progbar', action="store_true", default=False,
                        help="whether to disable progress bar")
    parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # Configure numpyro before main() runs: --device is forwarded to
    # set_platform, and the host device count is set to --num-chains.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
Source: pyro-ppl/numpyro, benchmarks/hmm.py (view on GitHub)
# Excerpt from the benchmark's entry point: the earlier lines (including the
# creation of `parser`) are not shown in this snippet.
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
    parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
    parser.add_argument("--num-chains", nargs='?', default=1, type=int)
    parser.add_argument("--num-categories", nargs='?', default=3, type=int)
    parser.add_argument("--num-words", nargs='?', default=10, type=int)
    parser.add_argument("--num-supervised", nargs='?', default=100, type=int)
    parser.add_argument("--num-unsupervised", nargs='?', default=500, type=int)
    parser.add_argument("--seed", nargs='?', default=2019, type=int)
    parser.add_argument("--device", default='cpu', type=str, help='use "cpu" or "gpu".')
    # Which inference backend to benchmark ("numpyro", "pyro" or "stan" per the help text).
    parser.add_argument("--backend", default='numpyro', type=str, help='either "numpyro", "pyro", or "stan"')
    parser.add_argument("--x64", action="store_true")
    parser.add_argument("--disable-progbar", action="store_true")
    args = parser.parse_args()

    # Global numpyro configuration, applied before main() runs:
    # 64-bit mode follows --x64, the platform comes from --device,
    # and the host device count is set to --num-chains.
    numpyro.enable_x64(args.x64)
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)
    # Apply the same device/precision choices to torch: use the torch.cuda
    # tensor types when --device is "gpu", and Double vs Float default tensors
    # depending on --x64.
    # NOTE(review): presumably this serves the torch-based "pyro" backend; confirm.
    tt = torch.cuda if args.device == "gpu" else torch
    torch.set_default_tensor_type(tt.DoubleTensor if args.x64 else tt.FloatTensor)

    main(args)
Source: pyro-ppl/numpyro, examples/pairwise.py (view on GitHub)
# Excerpt from the script entry point: the earlier lines (including the
# creation of `parser`) are not shown in this snippet.
parser.add_argument("--num-chains", default=1, type=int)
    parser.add_argument("--mtd", default=5, type=int)
    parser.add_argument("--num-dimensions", default=100, type=int)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--lr", default=0.005, type=float)
    parser.add_argument("--cg-tol", default=0.001, type=float)
    parser.add_argument("--active-dimensions", default=14, type=int)
    parser.add_argument("--thinning", default=10, type=int)
    # Unlike the other snippets on this page, --device defaults to "gpu" here.
    parser.add_argument("--device", default='gpu', type=str, help='use "cpu" or "gpu".')
    parser.add_argument("--likelihood", default='bernoulli', type=str)
    parser.add_argument("--log-dir", default='./', type=str)
    parser.add_argument("--results-file", default='results.out', type=str)
    parser.add_argument("--double", action="store_true")
    args = parser.parse_args()

    # Configure numpyro before main() runs: --device is forwarded to
    # set_platform, and the host device count is set to --num-chains.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    # --double turns on 64-bit mode.
    # NOTE(review): enable_x64 is imported elsewhere (not shown); presumably
    # numpyro's/JAX's enable_x64. Confirm.
    if args.double:
        enable_x64()

    main(args)
Source: pyro-ppl/numpyro, examples/ucbadmit.py (view on GitHub)
ax.legend()

    plt.savefig("ucbadmit_plot.pdf")
    plt.tight_layout()


if __name__ == '__main__':
    # Refuse to run against a numpyro release this example was not written
    # for. An explicit check is used instead of `assert` so that the guard is
    # not silently skipped when Python runs with -O (which strips asserts).
    if not numpyro.__version__.startswith('0.2.3'):
        raise RuntimeError(
            'This example expects numpyro 0.2.3, found {}'.format(numpyro.__version__))

    # Command-line options; every option has a default, so the script can be
    # run without arguments.
    parser = argparse.ArgumentParser(description='UCBadmit gender discrimination using HMC')
    parser.add_argument('-n', '--num-samples', nargs='?', default=2000, type=int)
    parser.add_argument('--num-warmup', nargs='?', default=500, type=int)
    parser.add_argument('--num-chains', nargs='?', default=1, type=int)
    parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # Configure numpyro before main() runs: --device is forwarded to
    # set_platform, and the host device count is set to --num-chains.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)