Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_bernoulli_2(self):
    """Sample bernoulli with explicit NUTS settings and check per-chain files.

    Compiles ``bernoulli.stan`` when the executable is missing, calls
    ``sample`` with chains=4, cores=2, seed=12345, sampling_iters=100,
    max_treedepth=11 and adapt_delta=0.95, then asserts that for every
    chain both the CSV file and the ``.txt`` file sharing its stem exist.
    """
    stan_path = os.path.join(datafiles_path, 'bernoulli.stan')
    exe_path = os.path.join(datafiles_path, 'bernoulli')
    if not os.path.exists(exe_path):
        compile_model(stan_path)
    bern_model = Model(stan_path, exe_file=exe_path)
    data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
    fit = sample(
        bern_model,
        chains=4,
        cores=2,
        seed=12345,
        sampling_iters=100,
        data=data_path,
        max_treedepth=11,
        adapt_delta=0.95,
    )
    for chain_idx in range(fit.chains):
        csv_path = fit.csv_files[chain_idx]
        # The companion text file has the same stem as the CSV, suffix .txt.
        stem, _ext = os.path.splitext(csv_path)
        txt_path = stem + '.txt'
        self.assertTrue(os.path.exists(csv_path))
        self.assertTrue(os.path.exists(txt_path))
def test_postsample_good(self):
    """Check chains, draws, columns, summary and diagnostics of a default run.

    Samples bernoulli with default settings (data passed via ``data_file``)
    and asserts 4 chains, 1000 draws, the expected column names, a
    ``(1000, 4, 8)`` sample array, a ``(2, 9)`` summary, and that
    ``diagnose()`` prints exactly 'No problems detected.'.
    """
    from contextlib import redirect_stdout

    column_names = (
        'lp__', 'accept_stat__', 'stepsize__', 'treedepth__',
        'n_leapfrog__', 'divergent__', 'energy__', 'theta',
    )
    stan = os.path.join(datafiles_path, 'bernoulli.stan')
    exe = os.path.join(datafiles_path, 'bernoulli')
    if not os.path.exists(exe):
        compile_model(stan)
    model = Model(stan, exe_file=exe)
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    post_sample = sample(model, data_file=jdata)
    self.assertEqual(post_sample.chains, 4)
    self.assertEqual(post_sample.draws, 1000)
    self.assertEqual(post_sample.column_names, column_names)
    # Presumably (draws, chains, columns): 1000 draws, 4 chains and the
    # 8 names above -- confirm against the sample property's docs.
    self.assertEqual(post_sample.sample.shape, (1000, 4, 8))
    df = post_sample.summary()
    self.assertEqual(df.shape, (2, 9))
    captured_output = io.StringIO()
    # redirect_stdout restores the previous sys.stdout even if diagnose()
    # raises; the old manual swap leaked the redirection on failure and
    # reset to sys.__stdout__, bypassing any outer output capture.
    with redirect_stdout(captured_output):
        post_sample.diagnose()
    self.assertEqual(captured_output.getvalue(), 'No problems detected.\n')
def diagnose_no_problems(self):
    """Check that ``diagnose`` prints 'No problems detected.' for bernoulli.

    Samples bernoulli with chains=4, cores=2, seed=12345, sampling_iters=200
    and compares the captured stdout of ``diagnose(post_sample)``.

    NOTE(review): the name lacks the ``test_`` prefix, so default
    unittest/pytest discovery will not collect this method; rename to
    ``test_diagnose_no_problems`` if it is meant to run.
    """
    from contextlib import redirect_stdout

    stan = os.path.join(datafiles_path, 'bernoulli.stan')
    exe = os.path.join(datafiles_path, 'bernoulli')
    if not os.path.exists(exe):
        compile_model(stan)
    model = Model(stan, exe_file=exe)
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    post_sample = sample(
        model, chains=4, cores=2, seed=12345, sampling_iters=200, data=jdata
    )
    captured_output = io.StringIO()
    # redirect_stdout restores the previous sys.stdout even if diagnose()
    # raises, unlike the former manual sys.stdout swap.
    with redirect_stdout(captured_output):
        diagnose(post_sample)
    self.assertEqual(captured_output.getvalue(), 'No problems detected.\n')
def test_bernoulli_1(self):
    """Sample bernoulli with an explicit ``csv_output_file`` and check files.

    Passes ``test1-bernoulli-output`` under the data directory as
    ``csv_output_file`` and asserts that, for every chain, the CSV file and
    the ``.txt`` file sharing its stem exist.
    """
    stan = os.path.join(datafiles_path, 'bernoulli.stan')
    exe = os.path.join(datafiles_path, 'bernoulli')
    if not os.path.exists(exe):
        compile_model(stan)
    model = Model(stan, exe_file=exe)
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    output = os.path.join(datafiles_path, 'test1-bernoulli-output')
    post_sample = sample(
        model,
        chains=4,
        cores=2,
        seed=12345,
        sampling_iters=100,
        data=jdata,
        csv_output_file=output,
        max_treedepth=11,
        adapt_delta=0.95,
    )
    for i in range(post_sample.chains):
        csv_file = post_sample.csv_files[i]
        txt_file = ''.join([os.path.splitext(csv_file)[0], '.txt'])
        # Previously the paths were computed and discarded, so the test
        # asserted nothing about the sampler's output files.
        self.assertTrue(os.path.exists(csv_file))
        self.assertTrue(os.path.exists(txt_file))
def test_diagnose_divergences(self):
# Builds a single-chain RunSet whose output path appears to point at
# existing sampler output in diagnose-good/corr_gauss_depth8, then defines
# the diagnostic message expected for it.
# NOTE(review): this definition is cut off in this view -- the join() call
# below is never closed and no assertion follows; the rest is not visible.
# NOTE(review): despite the name, the expected text is about hitting the
# treedepth limit, not divergent transitions -- confirm the intended check.
# NOTE(review): unlike the sampling tests, there is no compile_model()
# guard here if the bernoulli executable is missing.
stan = os.path.join(datafiles_path, 'bernoulli.stan')
exe = os.path.join(datafiles_path, 'bernoulli')
model = Model(exe_file=exe, stan_file=stan)
output = os.path.join(
datafiles_path, 'diagnose-good', 'corr_gauss_depth8'
)
args = SamplerArgs(model, chain_ids=[1], output_file=output)
runset = RunSet(args=args, chains=1)
# TODO - use cmdstan test files instead
# Expected message: 424 of 1000 transitions hit the treedepth limit of 8.
expected = ''.join(
[
'424 of 1000 (42%) transitions hit the maximum ',
'treedepth limit of 8, or 2^8 leapfrog steps. ',
'Trajectories that are prematurely terminated ',
'due to this limit will result in slow ',
'exploration and you should increase the ',
'limit to ensure optimal performance.\n',
]
def test_bernoulli(self):
    """Sample bernoulli, check per-chain files, then exercise save_csvfiles.

    Samples with chains=4, cores=2, seed=12345, sampling_iters=200 and
    asserts each chain's CSV file and same-stem ``.txt`` file exist. Then
    calls ``save_csvfiles`` with basename 'bern_save_csvfiles_test' into the
    data directory and asserts every path in ``csv_files`` still exists.
    """
    stan_path = os.path.join(datafiles_path, 'bernoulli.stan')
    exe_path = os.path.join(datafiles_path, 'bernoulli')
    if not os.path.exists(exe_path):
        compile_model(stan_path)
    bern_model = Model(stan_path, exe_file=exe_path)
    data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
    fit = sample(
        bern_model,
        chains=4,
        cores=2,
        seed=12345,
        sampling_iters=200,
        data=data_path,
    )
    for chain_idx in range(fit.chains):
        csv_path = fit.csv_files[chain_idx]
        stem, _ext = os.path.splitext(csv_path)
        txt_path = stem + '.txt'
        self.assertTrue(os.path.exists(csv_path))
        self.assertTrue(os.path.exists(txt_path))

    save_basename = 'bern_save_csvfiles_test'
    save_csvfiles(fit, datafiles_path, save_basename)  # good
    # Re-read csv_files after saving; the paths listed there must exist.
    for chain_idx in range(fit.chains):
        self.assertTrue(os.path.exists(fit.csv_files[chain_idx]))
def test_sample_big(self):
# construct runset using existing sampler output
# Builds a two-chain RunSet over runset-big/output_icar_nyc, validates and
# assembles it, then builds the expected column names: the 7 sampler-state
# columns followed by phi.1 ... phi.2095.
# NOTE(review): this definition is cut off in this view -- column_names is
# built but no assertion using it is visible.
stan = os.path.join(datafiles_path, 'bernoulli.stan')
exe = os.path.join(datafiles_path, 'bernoulli')
model = Model(exe_file=exe, stan_file=stan)
output = os.path.join(datafiles_path, 'runset-big', 'output_icar_nyc')
args = SamplerArgs(model, chain_ids=[1, 2], output_file=output)
runset = RunSet(chains=2, args=args)
runset.validate_csv_files()
runset.assemble_sample()
sampler_state = [
'lp__',
'accept_stat__',
'stepsize__',
'treedepth__',
'n_leapfrog__',
'divergent__',
'energy__',
]
# 1-based parameter names phi.1 .. phi.2095 (str() inside format is redundant).
phis = ['phi.{}'.format(str(x + 1)) for x in range(2095)]
column_names = sampler_state + phis
# NOTE(review): fragment -- this appears to be the tail of a model-compiling
# function whose header and earlier statements (where hpp_file, stan_file,
# exe_file, overwrite and opt_lvl are bound) are not visible here; the
# indentation has also been lost in this copy.
if not os.path.exists(hpp_file):
# NOTE(review): 'syntax error' has no {} placeholder, so .format(stan_file)
# drops the file name from the message; presumably meant to include it.
raise Exception('syntax error'.format(stan_file))
# Platforms whose name starts with 'win' get an .exe suffix on the target.
if platform.system().lower().startswith('win'):
exe_file += '.exe'
# An existing executable is reused unless overwrite is set; make is not run.
if not overwrite and os.path.exists(exe_file):
# print('model is up to date') # notify user or not?
return Model(stan_file, exe_file)
# make receives a forward-slash path (presumably so the target also works on
# Windows -- confirm).
exe_file_path = Path(exe_file).as_posix()
cmd = ['make', 'O={}'.format(opt_lvl), exe_file_path]
print('compiling c++: make args {}'.format(cmd))
try:
do_command(cmd, cmdstan_path())
# NOTE(review): any exception from the make step is swallowed and a Model
# without exe_file is returned, so the caller cannot see why compilation
# failed; consider logging or re-raising.
except Exception:
return Model(stan_file)
return Model(stan_file, exe_file)
# NOTE(review): fragment -- the lines that build cmd are not visible; this
# call presumably runs the step expected to produce hpp_file (checked next).
# The statements after it repeat the preceding fragment line for line,
# which looks like a copy/paste artifact.
do_command(cmd)
if not os.path.exists(hpp_file):
# NOTE(review): 'syntax error' has no {} placeholder, so .format(stan_file)
# drops the file name from the message; presumably meant to include it.
raise Exception('syntax error'.format(stan_file))
# Platforms whose name starts with 'win' get an .exe suffix on the target.
if platform.system().lower().startswith('win'):
exe_file += '.exe'
# An existing executable is reused unless overwrite is set; make is not run.
if not overwrite and os.path.exists(exe_file):
# print('model is up to date') # notify user or not?
return Model(stan_file, exe_file)
# make receives a forward-slash path (presumably so the target also works on
# Windows -- confirm).
exe_file_path = Path(exe_file).as_posix()
cmd = ['make', 'O={}'.format(opt_lvl), exe_file_path]
print('compiling c++: make args {}'.format(cmd))
try:
do_command(cmd, cmdstan_path())
# NOTE(review): any exception from the make step is swallowed and a Model
# without exe_file is returned, so the caller cannot see why compilation
# failed; consider logging or re-raising.
except Exception:
return Model(stan_file)
return Model(stan_file, exe_file)