How to use the cmdstanpy.cmds.sample function in cmdstanpy

To help you get started, we’ve selected a few cmdstanpy examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source: stan-dev/cmdstanpy, test/test_cmds.py (view on GitHub)
def test_bernoulli_data(self):
        """Sample bernoulli.stan with data supplied as an in-memory dict.

        Every chain must leave behind its CSV output file and a ``.txt``
        file sharing the CSV file's stem.
        """
        bernoulli_data = {'N': 10, 'y': [0, 1, 0, 0, 0, 0, 0, 0, 0, 1]}
        stan_path = os.path.join(datafiles_path, 'bernoulli.stan')
        output_base = os.path.join(TMPDIR, 'test3-bernoulli-output')
        compiled = compile_model(stan_path)
        fit = sample(compiled, data=bernoulli_data, csv_output_file=output_base)
        for chain_idx in range(fit.chains):
            csv_path = fit.csv_files[chain_idx]
            stem = os.path.splitext(csv_path)[0]
            txt_path = stem + '.txt'
            self.assertTrue(os.path.exists(csv_path))
            self.assertTrue(os.path.exists(txt_path))
Source: stan-dev/cmdstanpy, test/test_cmds.py (view on GitHub)
def test_bernoulli(self):
        stan = os.path.join(datafiles_path, 'bernoulli.stan')
        exe = os.path.join(datafiles_path, 'bernoulli')
        if not os.path.exists(exe):
            compile_model(stan)
        model = Model(stan, exe_file=exe)
        jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
        post_sample = sample(
            model, chains=4, cores=2, seed=12345, sampling_iters=200, data=jdata
        )
        for i in range(post_sample.chains):
            csv_file = post_sample.csv_files[i]
            txt_file = ''.join([os.path.splitext(csv_file)[0], '.txt'])
            self.assertTrue(os.path.exists(csv_file))
            self.assertTrue(os.path.exists(txt_file))

        basename = 'bern_save_csvfiles_test'
        save_csvfiles(post_sample, datafiles_path, basename)  # good
        for i in range(post_sample.chains):
            csv_file = post_sample.csv_files[i]
            self.assertTrue(os.path.exists(csv_file))

        with self.assertRaisesRegex(Exception, 'cannot save'):
            save_csvfiles(
Source: stan-dev/cmdstanpy, test/test_cmds.py (view on GitHub)
def test_bernoulli_rdata(self):
        """Sample bernoulli.stan with data read from an R dump file.

        Every chain must leave behind its CSV output file and a ``.txt``
        file sharing the CSV file's stem.
        """
        rdata = os.path.join(datafiles_path, 'bernoulli.data.R')
        stan = os.path.join(datafiles_path, 'bernoulli.stan')
        # Use an output basename distinct from 'test3-bernoulli-output',
        # which test_bernoulli_data also writes in TMPDIR; sharing it lets
        # files left over from that test satisfy the existence checks below.
        output = os.path.join(TMPDIR, 'test3-bernoulli-rdata-output')
        model = compile_model(stan)
        post_sample = sample(model, data=rdata, csv_output_file=output)
        for i in range(post_sample.chains):
            csv_file = post_sample.csv_files[i]
            txt_file = os.path.splitext(csv_file)[0] + '.txt'
            self.assertTrue(os.path.exists(csv_file))
            self.assertTrue(os.path.exists(txt_file))
Source: stan-dev/cmdstanpy, test/test_cmds.py (view on GitHub)
def test_bernoulli_2(self):
        """Sample with tuned NUTS settings, letting ``sample`` pick output paths.

        No ``csv_output_file`` is passed, so the output location is left to
        ``sample``. The test only requires that each chain's CSV file and its
        same-stem ``.txt`` companion exist afterwards.
        """
        stan_path = os.path.join(datafiles_path, 'bernoulli.stan')
        exe_path = os.path.join(datafiles_path, 'bernoulli')
        # Compile only when no executable is present yet.
        if not os.path.exists(exe_path):
            compile_model(stan_path)
        bernoulli_model = Model(stan_path, exe_file=exe_path)
        json_data = os.path.join(datafiles_path, 'bernoulli.data.json')
        sampler_kwargs = dict(
            chains=4,
            cores=2,
            seed=12345,
            sampling_iters=100,
            data=json_data,
            max_treedepth=11,
            adapt_delta=0.95,
        )
        fit = sample(bernoulli_model, **sampler_kwargs)
        for chain_idx in range(fit.chains):
            csv_path = fit.csv_files[chain_idx]
            stem = os.path.splitext(csv_path)[0]
            txt_path = stem + '.txt'
            self.assertTrue(os.path.exists(csv_path))
            self.assertTrue(os.path.exists(txt_path))
Source: stan-dev/cmdstanpy, test/test_cmds.py (view on GitHub)
def test_bernoulli_1(self):
        """Sample with tuned NUTS settings and a caller-chosen output basename.

        Runs ``sample`` with chains=4, cores=2, seed=12345,
        sampling_iters=100, max_treedepth=11 and adapt_delta=0.95, then
        checks that every chain left a CSV file and a same-stem ``.txt``
        file on disk.
        """
        stan = os.path.join(datafiles_path, 'bernoulli.stan')
        exe = os.path.join(datafiles_path, 'bernoulli')
        # Compile only when no executable is present yet.
        if not os.path.exists(exe):
            compile_model(stan)
        model = Model(stan, exe_file=exe)
        jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
        # Write sampler output to the scratch directory, as the other tests
        # in this file do, instead of into the checked-in test-data folder.
        output = os.path.join(TMPDIR, 'test1-bernoulli-output')
        post_sample = sample(
            model,
            chains=4,
            cores=2,
            seed=12345,
            sampling_iters=100,
            data=jdata,
            csv_output_file=output,
            max_treedepth=11,
            adapt_delta=0.95,
        )
        for i in range(post_sample.chains):
            csv_file = post_sample.csv_files[i]
            txt_file = os.path.splitext(csv_file)[0] + '.txt'
            self.assertTrue(os.path.exists(csv_file))
            self.assertTrue(os.path.exists(txt_file))
Source: stan-dev/cmdstanpy, test/test_posterior_sample.py (view on GitHub)
def test_postsample_good(self):
        """Check draws, column names, summary and diagnostics of a basic run.

        Samples bernoulli.stan passing only the JSON data file, then expects
        4 chains of 1000 draws over 8 columns, a (2, 9) summary table, and
        the exact ``diagnose`` message 'No problems detected.'.
        """
        import contextlib

        column_names = [
            'lp__',
            'accept_stat__',
            'stepsize__',
            'treedepth__',
            'n_leapfrog__',
            'divergent__',
            'energy__',
            'theta',
        ]
        stan = os.path.join(datafiles_path, 'bernoulli.stan')
        exe = os.path.join(datafiles_path, 'bernoulli')
        # Compile only when no executable is present yet.
        if not os.path.exists(exe):
            compile_model(stan)
        model = Model(stan, exe_file=exe)
        jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
        post_sample = sample(model, data_file=jdata)
        self.assertEqual(post_sample.chains, 4)
        self.assertEqual(post_sample.draws, 1000)
        self.assertEqual(post_sample.column_names, tuple(column_names))
        self.assertEqual(post_sample.sample.shape, (1000, 4, 8))
        df = post_sample.summary()
        self.assertEqual(df.shape, (2, 9))
        # redirect_stdout restores the previous sys.stdout even if diagnose()
        # raises, and it restores the stream that was active (e.g. a test
        # runner's capture) rather than forcing sys.__stdout__.
        captured_output = io.StringIO()
        with contextlib.redirect_stdout(captured_output):
            post_sample.diagnose()
        self.assertEqual(captured_output.getvalue(), 'No problems detected.\n')
Source: stan-dev/cmdstanpy, test/test_cmds.py (view on GitHub)
def test_bernoulli(self):
        stan = os.path.join(datafiles_path, 'bernoulli.stan')
        exe = os.path.join(datafiles_path, 'bernoulli')
        if not os.path.exists(exe):
            compile_model(stan)
        model = Model(stan, exe_file=exe)
        jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
        post_sample = sample(
            model, chains=4, cores=2, seed=12345, sampling_iters=200, data=jdata
        )
        post_sample.assemble_sample()
        df = get_drawset(post_sample)
        self.assertEqual(
            df.shape,
            (
                post_sample.chains * post_sample.draws,
                len(post_sample.column_names),
            ),
Source: stan-dev/cmdstanpy, test/test_cmds.py (view on GitHub)
def diagnose_no_problems(self):
        """Expect ``diagnose`` to print exactly 'No problems detected.'.

        Samples bernoulli.stan with chains=4, cores=2, seed=12345 and
        sampling_iters=200, then captures what ``diagnose`` writes to stdout.

        NOTE(review): the name lacks the ``test_`` prefix, so default
        unittest/pytest discovery will not collect this method. Confirm
        whether that is intentional before renaming it.
        """
        import contextlib

        stan = os.path.join(datafiles_path, 'bernoulli.stan')
        exe = os.path.join(datafiles_path, 'bernoulli')
        # Compile only when no executable is present yet.
        if not os.path.exists(exe):
            compile_model(stan)
        model = Model(stan, exe_file=exe)
        jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
        post_sample = sample(
            model, chains=4, cores=2, seed=12345, sampling_iters=200, data=jdata
        )
        # redirect_stdout restores the previous sys.stdout even if diagnose()
        # raises, and it restores the stream that was active rather than
        # forcing sys.__stdout__.
        captured_output = io.StringIO()
        with contextlib.redirect_stdout(captured_output):
            diagnose(post_sample)
        self.assertEqual(captured_output.getvalue(), 'No problems detected.\n')
Source: stan-dev/cmdstanpy, test/test_cmds.py (view on GitHub)
def test_missing_input(self):
        """Sampling bernoulli.stan without any data input must raise.

        The raised exception's message must match 'Error during sampling'.
        """
        stan = os.path.join(datafiles_path, 'bernoulli.stan')
        output = os.path.join(TMPDIR, 'test4-bernoulli-output')
        model = compile_model(stan)
        # No data argument is passed. The call is expected to raise, so its
        # return value is never used and is not bound to a name.
        with self.assertRaisesRegex(Exception, 'Error during sampling'):
            sample(model, csv_output_file=output)
Source: stan-dev/cmdstanpy, test/test_cmds.py (view on GitHub)
def test_bernoulli(self):
        """Check the shape of ``summary`` output for a bernoulli fit.

        Samples with chains=4, cores=2, seed=12345 and sampling_iters=200,
        and expects a summary table of shape (2, 9).
        """
        stan = os.path.join(datafiles_path, 'bernoulli.stan')
        exe = os.path.join(datafiles_path, 'bernoulli')
        # Compile only when no executable is present yet.
        if not os.path.exists(exe):
            compile_model(stan)
        model = Model(stan, exe_file=exe)
        jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
        post_sample = sample(
            model, chains=4, cores=2, seed=12345, sampling_iters=200, data=jdata
        )
        df = summary(post_sample)
        # assertEqual reports the actual shape on failure, whereas
        # assertTrue(a == b) only reports "False is not true".
        self.assertEqual(df.shape, (2, 9))