How to use the cmdstanpy.utils.EXTENSION function in cmdstanpy

To help you get started, we’ve selected a few cmdstanpy examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github stan-dev / cmdstanpy / test / test_model.py View on Github external
# Stan program text held as a string literal: a Bernoulli model with N integer
# observations y and a single parameter theta given a beta(1,1) prior.
# NOTE(review): presumably this mirrors the contents of bernoulli.stan so
# tests can compare against it -- confirm against the rest of the test file.
CODE = """data {
  int N;
  int y[N];
}
parameters {
  real theta;
}
model {
  theta ~ beta(1,1);  // uniform prior on interval 0,1
  y ~ bernoulli(theta);
}
"""

# Locations, under DATAFILES_PATH, of the Bernoulli Stan source file and of
# the file named 'bernoulli' with EXTENSION appended (the model executable).
BERN_STAN = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
BERN_EXE = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)


class CmdStanModelTest(unittest.TestCase):
    """Test case for CmdStanModel that also clears build artifacts."""

    # pylint: disable=no-self-use
    @pytest.fixture(scope='class', autouse=True)
    def do_clean_up(self):
        """
        Remove compiler by-products found anywhere under DATAFILES_PATH.

        Registered as a class-scoped, autouse pytest fixture. It does not
        yield, so the whole body runs during fixture setup.

        Any file whose extension (compared case-insensitively) is one of
        .o, .d, .hpp, .exe, or which has no extension at all, is deleted.
        """
        removable_suffixes = ('.o', '.d', '.hpp', '.exe', '')
        for dirpath, _dirnames, names in os.walk(DATAFILES_PATH):
            for name in names:
                suffix = os.path.splitext(name)[1].lower()
                if suffix not in removable_suffixes:
                    continue
                os.remove(os.path.join(dirpath, name))

    def show_cmdstan_version(self):
        """
        Print a 'CmdStan version' banner to stdout.

        NOTE(review): the value interpolated into the banner is the return
        value of cmdstan_path(), not a separately obtained version string.
        """
        banner = '\n\nCmdStan version: {}\n\n'
        print(banner.format(cmdstan_path()))
github stan-dev / cmdstanpy / test / test_runset.py View on Github external
def test_check_retcodes(self):
        exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        sampler_args = SamplerArgs()
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe,
            chain_ids=[1, 2, 3, 4],
            data=jdata,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args, chains=4)
        self.assertIn('RunSet: chains=4', runset.__repr__())
        self.assertIn('method=sample', runset.__repr__())

        retcodes = runset._retcodes
        self.assertEqual(4, len(retcodes))
        for i in range(len(retcodes)):
github stan-dev / cmdstanpy / test / test_optimize.py View on Github external
def test_optimize_good_dict(self):
        # Optimize the bernoulli model, supplying both the data and the
        # initial values as objects loaded from JSON files.
        def datafile(name):
            # Resolve a file name relative to the test data directory.
            return os.path.join(DATAFILES_PATH, name)

        model = CmdStanModel(
            stan_file=datafile('bernoulli.stan'),
            exe_file=datafile('bernoulli' + EXTENSION),
        )

        with open(datafile('bernoulli.data.json')) as data_fd:
            data = json.load(data_fd)
        with open(datafile('bernoulli.init.json')) as init_fd:
            init = json.load(init_fd)

        optimize_kwargs = {
            'data': data,
            'seed': 1239812093,
            'inits': init,
            'algorithm': 'BFGS',
            'init_alpha': 0.001,
            'iter': 100,
        }
        mle = model.optimize(**optimize_kwargs)

        # test numpy output
        first_value = mle.optimized_params_np[0]
        self.assertAlmostEqual(first_value, -5, places=2)
github stan-dev / cmdstanpy / test / test_sample.py View on Github external
def test_variables(self):
        # construct fit using existing sampler output:
        # describe a one-chain lotka-volterra sampling run, point the RunSet
        # at an existing CSV file, mark chain 0 as having exited with code 0,
        # and wrap the result in a CmdStanMCMC object.
        args = CmdStanArgs(
            model_name='lotka-volterra',
            model_exe=os.path.join(DATAFILES_PATH, 'lotka-volterra' + EXTENSION),
            chain_ids=[1],
            seed=12345,
            data=os.path.join(DATAFILES_PATH, 'lotka-volterra.data.json'),
            output_dir=DATAFILES_PATH,
            method_args=SamplerArgs(iter_sampling=20),
        )
        runset = RunSet(args=args, chains=1)

        existing_csv = os.path.join(DATAFILES_PATH, 'lotka-volterra.csv')
        runset._csv_files = [existing_csv]
        runset._set_retcode(0, 0)

        fit = CmdStanMCMC(runset)
        # The draw count must equal the iter_sampling value configured above.
        self.assertEqual(20, fit.num_draws)
github stan-dev / cmdstanpy / cmdstanpy / stanfit.py View on Github external
def summary(self) -> pd.DataFrame:
        """
        Run cmdstan/bin/stansummary over all output csv files.
        Echo stansummary stdout/stderr to console (presumably done by
        do_command -- confirm).
        Assemble csv tempfile contents into a pandas DataFrame.

        NOTE(review): the portion of this method visible here ends after the
        temporary CSV has been parsed; the return statement is not shown.
        """
        # Full path to the stansummary program inside the CmdStan install;
        # EXTENSION is appended to the program name.
        cmd_path = os.path.join(
            cmdstan_path(), 'bin', 'stansummary' + EXTENSION
        )
        # File-name prefix for the temporary CSV, built from the model name
        # and the number of chains in the runset.
        tmp_csv_file = 'stansummary-{}-{}-chain-'.format(
            self.runset._args.model_name, self.runset.chains
        )
        # Create the temporary .csv file in _TMPDIR that stansummary is told
        # to write to.
        tmp_csv_path = create_named_text_file(
            dir=_TMPDIR, prefix=tmp_csv_file, suffix='.csv'
        )
        # Command line: program, --csv_file=<temp path>, then every CSV file
        # of the runset as positional arguments.
        cmd = [
            cmd_path,
            '--csv_file={}'.format(tmp_csv_path),
        ] + self.runset.csv_files
        do_command(cmd, logger=self.runset._logger)
        # Parse the CSV written by stansummary: the first row is the header,
        # the first column becomes the index, and '#' starts a comment.
        with open(tmp_csv_path, 'rb') as fd:
            summary_data = pd.read_csv(
                fd, delimiter=',', header=0, index_col=0, comment='#'
            )
github stan-dev / cmdstanpy / cmdstanpy / model.py View on Github external
if not (stanc_options is None and cpp_options is None):
            compiler_options = CompilerOptions(
                stanc_options=stanc_options, cpp_options=cpp_options
            )
            compiler_options.validate()
            if self._compiler_options is None:
                self._compiler_options = compiler_options
            elif override_options:
                self._compiler_options = compiler_options
            else:
                self._compiler_options.add(compiler_options)

        compilation_failed = False
        with TemporaryCopiedFile(self._stan_file) as (stan_file, is_copied):
            exe_file, _ = os.path.splitext(os.path.abspath(stan_file))
            exe_file = Path(exe_file).as_posix() + EXTENSION
            do_compile = True
            if os.path.exists(exe_file):
                src_time = os.path.getmtime(self._stan_file)
                exe_time = os.path.getmtime(exe_file)
                if exe_time > src_time and not force:
                    do_compile = False
                    self._logger.info('found newer exe file, not recompiling')

            if do_compile:
                self._logger.info(
                    'compiling stan program, exe file: %s', exe_file
                )
                if self._compiler_options is not None:
                    self._compiler_options.validate()
                    self._logger.info(
                        'compiler options: %s', self._compiler_options