How to use the `cmdstanpy.cmdstan_args.CmdStanArgs` class in cmdstanpy

To help you get started, we’ve selected a few cmdstanpy examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github stan-dev / cmdstanpy / test / test_cmdstan_args.py View on Github external
CmdStanArgs(
                model_name='bernoulli',
                model_exe='bernoulli.exe',
                chain_ids=[1, 2, 3, 4],
                output_dir=fname,
                method_args=sampler_args,
            )
        if os.path.exists(fname):
            os.remove(fname)

        # TODO: read-only dir test for Windows - set ACLs, not mode
        if platform.system() == 'Darwin' or platform.system() == 'Linux':
            with self.assertRaises(ValueError):
                read_only = os.path.join(_TMPDIR, 'read_only')
                os.mkdir(read_only, mode=0o444)
                CmdStanArgs(
                    model_name='bernoulli',
                    model_exe='bernoulli.exe',
                    chain_ids=[1, 2, 3, 4],
                    output_dir=read_only,
                    method_args=sampler_args,
                )
github stan-dev / cmdstanpy / test / test_cmdstan_args.py View on Github external
jinits = os.path.join(DATAFILES_PATH, 'bernoulli.init.json')

        sampler_args = SamplerArgs()
        with self.assertRaises(ValueError):
            CmdStanArgs(
                model_name='bernoulli',
                model_exe=exe,
                chain_ids=None,
                seed=[1, 2, 3],
                data=jdata,
                inits=jinits,
                method_args=sampler_args,
            )

        with self.assertRaises(ValueError):
            CmdStanArgs(
                model_name='bernoulli',
                model_exe=exe,
                chain_ids=None,
                data=jdata,
                inits=[jinits],
                method_args=sampler_args,
            )
github stan-dev / cmdstanpy / test / test_runset.py View on Github external
def test_get_err_msgs(self):
        """Error messages are gathered from each failed chain's stdout file.

        Every chain is given return code 70 and pointed at a captured
        ``chain-<n>-missing-data-stdout.txt`` file from the test data
        directory; the joined messages must mention an ``Exception``.
        """
        n_chains = 3
        logistic_exe = os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION)
        logistic_data = os.path.join(DATAFILES_PATH, 'logistic.data.R')
        cmdstan_args = CmdStanArgs(
            model_name='logistic',
            model_exe=logistic_exe,
            chain_ids=list(range(1, n_chains + 1)),
            data=logistic_data,
            method_args=SamplerArgs(),
        )
        runset = RunSet(args=cmdstan_args, chains=n_chains)
        for chain_idx in range(n_chains):
            # mark the chain as failed, then swap in canned stdout output
            runset._set_retcode(chain_idx, 70)
            runset._stdout_files[chain_idx] = os.path.join(
                DATAFILES_PATH,
                'chain-{}-missing-data-stdout.txt'.format(chain_idx + 1),
            )
        err_text = '\n\t'.join(runset._get_err_msgs())
        self.assertIn('Exception', err_text)
github stan-dev / cmdstanpy / test / test_runset.py View on Github external
def test_output_filenames(self):
        """Per-chain CSV output names include the model name and chain id."""
        bern_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        bern_data = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        runset = RunSet(
            args=CmdStanArgs(
                model_name='bernoulli',
                model_exe=bern_exe,
                chain_ids=[1, 2, 3, 4],
                data=bern_data,
                method_args=SamplerArgs(),
            ),
            chains=4,
        )
        csv_names = runset._csv_files
        # first file: model-name prefix and chain id 1
        self.assertIn('bernoulli-', csv_names[0])
        self.assertIn('-1-', csv_names[0])
        # last file: chain id 4
        self.assertIn('-4-', csv_names[3])
github stan-dev / cmdstanpy / test / test_sample.py View on Github external
def test_validate_good_run(self):
        """Wrap pre-existing sampler CSV output in a RunSet.

        The RunSet's CSV paths are replaced by the four ``bern-<n>.csv``
        files under ``runset-good`` in the test data directory; as shown
        here, only the resulting chain count is asserted.
        """
        # construct fit using existing sampler output
        bern_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        bern_data = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        sampler_args = SamplerArgs(
            iter_sampling=100, max_treedepth=11, adapt_delta=0.95
        )
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=bern_exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=bern_data,
            output_dir=DATAFILES_PATH,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args, chains=4)
        good_dir = os.path.join(DATAFILES_PATH, 'runset-good')
        runset._csv_files = [
            os.path.join(good_dir, 'bern-{}.csv'.format(chain))
            for chain in range(1, 5)
        ]
        self.assertEqual(4, runset.chains)
github stan-dev / cmdstanpy / test / test_cmdstan_args.py View on Github external
def test_args_bad(self):
        sampler_args = SamplerArgs(iter_warmup=10, iter_sampling=20)

        with self.assertRaisesRegex(
            Exception, 'missing 2 required positional arguments'
        ):
            CmdStanArgs(model_name='bernoulli', model_exe='bernoulli.exe')

        with self.assertRaisesRegex(
            ValueError, 'no such file no/such/path/to.file'
        ):
            CmdStanArgs(
                model_name='bernoulli',
                model_exe='bernoulli.exe',
                chain_ids=[1, 2, 3, 4],
                data='no/such/path/to.file',
                method_args=sampler_args,
            )

        with self.assertRaisesRegex(ValueError, 'invalid chain_id'):
            CmdStanArgs(
                model_name='bernoulli',
                model_exe='bernoulli.exe',
                chain_ids=[1, 2, 3, -4],
                method_args=sampler_args,
            )

        with self.assertRaisesRegex(
github stan-dev / cmdstanpy / test / test_cmdstan_args.py View on Github external
self.assertIn('method=sample algorithm=hmc', ' '.join(cmd))

        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe,
            chain_ids=[7, 11, 18, 29],
            data=jdata,
            method_args=sampler_args,
        )
        cmd = cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv')
        self.assertIn('id=7 random seed=', ' '.join(cmd))

        dirname = 'tmp' + str(time())
        if os.path.exists(dirname):
            os.rmdir(dirname)
        CmdStanArgs(
            model_name='bernoulli',
            model_exe='bernoulli.exe',
            chain_ids=[1, 2, 3, 4],
            output_dir=dirname,
            method_args=sampler_args,
        )
        self.assertTrue(os.path.exists(dirname))
        os.rmdir(dirname)
github stan-dev / cmdstanpy / cmdstanpy / model.py View on Github external
sample_csv_files = mcmc_sample.runset.csv_files
            sample_drawset = mcmc_sample.get_drawset()
            chains = mcmc_sample.chains
        elif isinstance(mcmc_sample, list):
            sample_csv_files = mcmc_sample
        else:
            raise ValueError(
                'MCMC sample must be either CmdStanMCMC object'
                ' or list of paths to sample csv_files.'
            )

        try:
            chains = len(sample_csv_files)
            if sample_drawset is None:  # assemble sample from csv files
                sampler_args = SamplerArgs()
                args = CmdStanArgs(
                    self._name,
                    self._exe_file,
                    chain_ids=[x + 1 for x in range(chains)],
                    method_args=sampler_args,
                )
                runset = RunSet(args=args, chains=chains)
                runset._csv_files = sample_csv_files
                sample_fit = CmdStanMCMC(runset)
                sample_fit._validate_csv_files()
                sample_drawset = sample_fit.get_drawset()
        except ValueError as e:
            raise ValueError(
                'Invalid mcmc_sample, error:\n\t{}\n\t'
                ' while processing files\n\t{}'.format(
                    repr(e), '\n\t'.join(sample_csv_files)
                )
github stan-dev / cmdstanpy / cmdstanpy / model.py View on Github external
iter_warmup=iter_warmup,
            iter_sampling=iter_sampling,
            save_warmup=save_warmup,
            thin=thin,
            max_treedepth=max_treedepth,
            metric=metric,
            step_size=step_size,
            adapt_engaged=adapt_engaged,
            adapt_delta=adapt_delta,
            adapt_init_phase=adapt_init_phase,
            adapt_metric_window=adapt_metric_window,
            adapt_step_size=adapt_step_size,
            fixed_param=fixed_param,
        )
        with MaybeDictToFilePath(data, inits) as (_data, _inits):
            args = CmdStanArgs(
                self._name,
                self._exe_file,
                chain_ids=chain_ids,
                data=_data,
                seed=seed,
                inits=_inits,
                output_dir=output_dir,
                save_diagnostics=save_diagnostics,
                method_args=sampler_args,
                refresh=refresh,
                logger=self._logger,
            )
            runset = RunSet(args=args, chains=chains)
            pbar = None
            all_pbars = []
github stan-dev / cmdstanpy / cmdstanpy / model.py View on Github external
where `` is set with `csv_basename`.

        :param algorithm: Algorithm to use. One of: "BFGS", "LBFGS", "Newton"

        :param init_alpha: Line search step size for first iteration

        :param iter: Total number of iterations

        :return: CmdStanMLE object
        """
        optimize_args = OptimizeArgs(
            algorithm=algorithm, init_alpha=init_alpha, iter=iter
        )

        with MaybeDictToFilePath(data, inits) as (_data, _inits):
            args = CmdStanArgs(
                self._name,
                self._exe_file,
                chain_ids=None,
                data=_data,
                seed=seed,
                inits=_inits,
                output_dir=output_dir,
                save_diagnostics=save_diagnostics,
                method_args=optimize_args,
            )

            dummy_chain_id = 0
            runset = RunSet(args=args, chains=1)
            self._run_cmdstan(runset, dummy_chain_id)

        if not runset._check_retcodes():