Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
# Inline Stan source for a Bernoulli model: data N and y[N], a parameter
# theta with a beta(1,1) prior, and the likelihood y ~ bernoulli(theta).
# Not referenced in the visible portion of this file.
# NOTE(review): `int y[N];` is the legacy (pre-2.26) Stan array syntax,
# which newer Stan releases no longer accept (modern form:
# `array[N] int y;`); check it against the CmdStan version targeted here.
CODE = """data {
int N;
int y[N];
}
parameters {
real theta;
}
model {
theta ~ beta(1,1); // uniform prior on interval 0,1
y ~ bernoulli(theta);
}
"""
# Locations of the Bernoulli example model's Stan source and its compiled
# executable; EXTENSION is the executable-name suffix appended to the
# model's base name (presumably platform dependent, defined elsewhere).
BERN_STAN, BERN_EXE = (
    os.path.join(DATAFILES_PATH, 'bernoulli.stan'),
    os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION),
)
class CmdStanModelTest(unittest.TestCase):
    """Tests exercising CmdStanModel, RunSet and CmdStanMCMC.

    Inputs come from files under DATAFILES_PATH: the bernoulli and
    lotka-volterra example models plus their data, init and csv files.
    """

    # pylint: disable=no-self-use
    @pytest.fixture(scope='class', autouse=True)
    def do_clean_up(self):
        """Delete build artifacts under DATAFILES_PATH before the tests run.

        Walks DATAFILES_PATH recursively and removes every file whose
        extension (compared case-insensitively) is ``.o``, ``.d``,
        ``.hpp``, ``.exe`` or empty.  The empty-extension case presumably
        targets compiled model executables on platforms where EXTENSION is
        ``''``; note that it removes *any* extension-less file found there.
        """
        removable_exts = ('.o', '.d', '.hpp', '.exe', '')
        for root, _, files in os.walk(DATAFILES_PATH):
            for filename in files:
                _, ext = os.path.splitext(filename)
                if ext.lower() in removable_exts:
                    os.remove(os.path.join(root, filename))

    def show_cmdstan_version(self):
        """Print the value of cmdstan_path(), framed by blank lines.

        Has no ``test_`` prefix, so test runners do not collect it.
        """
        # NOTE(review): the label says "version" but the value printed is
        # cmdstan_path(), i.e. the installation path -- confirm intent.
        print('\n\nCmdStan version: {}\n\n'.format(cmdstan_path()))

    def test_check_retcodes(self):
        """A freshly built 4-chain RunSet describes itself and has 4 retcodes."""
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        sampler_args = SamplerArgs()
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=BERN_EXE,
            chain_ids=[1, 2, 3, 4],
            data=jdata,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args, chains=4)
        runset_repr = repr(runset)
        self.assertIn('RunSet: chains=4', runset_repr)
        self.assertIn('method=sample', runset_repr)
        retcodes = runset._retcodes
        self.assertEqual(4, len(retcodes))
        # NOTE(review): the body of this per-chain loop was lost from the
        # source (its header was followed directly by the next method);
        # restore the original per-chain retcode assertions here.
        for i in range(len(retcodes)):
            pass

    def test_optimize_good_dict(self):
        """Optimize bernoulli with data and inits passed as in-memory dicts."""
        model = CmdStanModel(stan_file=BERN_STAN, exe_file=BERN_EXE)
        data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        init_path = os.path.join(DATAFILES_PATH, 'bernoulli.init.json')
        with open(data_path, encoding='utf-8') as fd:
            data = json.load(fd)
        with open(init_path, encoding='utf-8') as fd:
            init = json.load(fd)
        mle = model.optimize(
            data=data,
            seed=1239812093,
            inits=init,
            algorithm='BFGS',
            init_alpha=0.001,
            iter=100,
        )
        # test numpy output: first entry of optimized_params_np (presumably
        # lp__ -- confirm) should be close to -5.
        self.assertAlmostEqual(mle.optimized_params_np[0], -5, places=2)

    def test_variables(self):
        """Wrap pre-existing lotka-volterra sampler output in a CmdStanMCMC."""
        # construct fit using existing sampler output
        exe = os.path.join(DATAFILES_PATH, 'lotka-volterra' + EXTENSION)
        jdata = os.path.join(DATAFILES_PATH, 'lotka-volterra.data.json')
        sampler_args = SamplerArgs(iter_sampling=20)
        cmdstan_args = CmdStanArgs(
            model_name='lotka-volterra',
            model_exe=exe,
            chain_ids=[1],
            seed=12345,
            data=jdata,
            output_dir=DATAFILES_PATH,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args, chains=1)
        # Point the runset at the checked-in csv and mark chain 0 as having
        # returned 0 before handing it to CmdStanMCMC.
        runset._csv_files = [
            os.path.join(DATAFILES_PATH, 'lotka-volterra.csv')
        ]
        runset._set_retcode(0, 0)
        fit = CmdStanMCMC(runset)
        self.assertEqual(20, fit.num_draws)
def summary(self) -> pd.DataFrame:
    """
    Run cmdstan/bin/stansummary over all output csv files.

    stansummary is executed through ``do_command`` with the runset's
    logger and is told (via ``--csv_file``) to write its results to a csv
    file created in ``_TMPDIR``; that file is then parsed into a pandas
    DataFrame.

    :return: DataFrame of the stansummary output, indexed by its first csv
        column, with lines starting with ``#`` skipped.
    """
    cmd_path = os.path.join(
        cmdstan_path(), 'bin', 'stansummary' + EXTENSION
    )
    # Temp-file prefix records the model name and chain count.
    tmp_csv_file = 'stansummary-{}-{}-chain-'.format(
        self.runset._args.model_name, self.runset.chains
    )
    tmp_csv_path = create_named_text_file(
        dir=_TMPDIR, prefix=tmp_csv_file, suffix='.csv'
    )
    cmd = [
        cmd_path,
        '--csv_file={}'.format(tmp_csv_path),
    ] + self.runset.csv_files
    do_command(cmd, logger=self.runset._logger)
    # NOTE(review): the temp csv is not removed here; whether _TMPDIR is
    # cleaned up elsewhere cannot be told from this block -- confirm.
    with open(tmp_csv_path, 'rb') as fd:
        summary_data = pd.read_csv(
            fd, delimiter=',', header=0, index_col=0, comment='#'
        )
    return summary_data
if not (stanc_options is None and cpp_options is None):
compiler_options = CompilerOptions(
stanc_options=stanc_options, cpp_options=cpp_options
)
compiler_options.validate()
if self._compiler_options is None:
self._compiler_options = compiler_options
elif override_options:
self._compiler_options = compiler_options
else:
self._compiler_options.add(compiler_options)
compilation_failed = False
with TemporaryCopiedFile(self._stan_file) as (stan_file, is_copied):
exe_file, _ = os.path.splitext(os.path.abspath(stan_file))
exe_file = Path(exe_file).as_posix() + EXTENSION
do_compile = True
if os.path.exists(exe_file):
src_time = os.path.getmtime(self._stan_file)
exe_time = os.path.getmtime(exe_file)
if exe_time > src_time and not force:
do_compile = False
self._logger.info('found newer exe file, not recompiling')
if do_compile:
self._logger.info(
'compiling stan program, exe file: %s', exe_file
)
if self._compiler_options is not None:
self._compiler_options.validate()
self._logger.info(
'compiler options: %s', self._compiler_options