Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_prior_factory_init_from_config():
    """PriorFactory can be built from a JSON config path or an equivalent dict.

    Both construction routes must expose ``dists``, ``terms`` and ``families``
    as dicts, and the dict-built factory must contain the sample entries
    'feta', 'hard' and 'yellow' from ``data/sample_priors.json``.

    NOTE(review): a second function with this exact name is defined later in
    this module. That later definition rebinds the name, so pytest only
    collects the later copy and this one never runs. Consider deleting one.
    """
    def check_tables(factory):
        # Every lookup table must exist and be a plain dict.
        for attr in ['dists', 'terms', 'families']:
            assert hasattr(factory, attr)
            assert isinstance(getattr(factory, attr), dict)

    config_file = join(dirname(__file__), 'data', 'sample_priors.json')

    # Route 1: construct from the config file path.
    check_tables(PriorFactory(config_file))

    # Route 2: construct from the already-parsed dict. The context manager
    # closes the handle deterministically instead of leaking it until garbage
    # collection. JSON is UTF-8 by spec, so don't depend on the locale default.
    with open(config_file, 'r', encoding='utf-8') as f:
        config_dict = json.load(f)
    pf = PriorFactory(config_dict)
    check_tables(pf)
    assert 'feta' in pf.dists
    assert 'hard' in pf.families
    assert 'yellow' in pf.terms
def test_prior_factory_init_from_default_config():
    """A PriorFactory constructed with no arguments uses its default config.

    Checks that the three lookup tables exist as dicts, then spot-checks one
    expected entry in each table.
    """
    factory = PriorFactory()
    for table_name in ('dists', 'terms', 'families'):
        assert hasattr(factory, table_name)
        table = getattr(factory, table_name)
        assert isinstance(table, dict)
    # One well-known default entry per table.
    assert 'normal' in factory.dists
    assert 'fixed' in factory.terms
    assert 'gaussian' in factory.families
def test_prior_retrieval():
    """PriorFactory.get resolves priors by dist, family and term name.

    Uses the sample config in ``data/sample_priors.json``.
    """
    config_file = join(dirname(__file__), 'data', 'sample_priors.json')
    pf = PriorFactory(config_file)

    # Lookup by distribution name.
    prior = pf.get(dist='asiago')
    assert prior.name == 'Asiago'
    assert isinstance(prior, Prior)
    assert prior.args['hardness'] == 10
    # 'holes' must be absent from the args. Only the subscript belongs inside
    # pytest.raises: an equality assert there would never execute once the
    # KeyError fires, and it obscures what is being tested.
    with pytest.raises(KeyError):
        prior.args['holes']

    # Lookup by family name; the family's prior nests another Prior in its args.
    family = pf.get(family='hard')
    assert isinstance(family, Family)
    assert family.link == 'grate'
    backup = family.prior.args['backup']
    assert isinstance(backup, Prior)
    assert backup.args['flavor'] == 10000

    # Lookup by term name.
    prior = pf.get(term='yellow')
    assert prior.name == 'Swiss'
def test_prior_factory_init_from_config():
    """PriorFactory can be built from a JSON config path or an equivalent dict.

    Both construction routes must expose ``dists``, ``terms`` and ``families``
    as dicts, and the dict-built factory must contain the sample entries
    'feta', 'hard' and 'yellow' from ``data/sample_priors.json``.

    NOTE(review): an earlier function with this exact name is defined above in
    this module. Because this definition comes later, it rebinds the name and
    is the only copy pytest collects. Consider deleting the duplicate.
    """
    def check_tables(factory):
        # Every lookup table must exist and be a plain dict.
        for attr in ['dists', 'terms', 'families']:
            assert hasattr(factory, attr)
            assert isinstance(getattr(factory, attr), dict)

    config_file = join(dirname(__file__), 'data', 'sample_priors.json')

    # Route 1: construct from the config file path.
    check_tables(PriorFactory(config_file))

    # Route 2: construct from the already-parsed dict. The context manager
    # closes the handle deterministically instead of leaking it until garbage
    # collection. JSON is UTF-8 by spec, so don't depend on the locale default.
    with open(config_file, 'r', encoding='utf-8') as f:
        config_dict = json.load(f)
    pf = PriorFactory(config_dict)
    check_tables(pf)
    assert 'feta' in pf.dists
    assert 'hard' in pf.families
    assert 'yellow' in pf.terms
def __init__(
    self,
    data=None,
    default_priors=None,
    auto_scale=True,
    dropna=False,
    taylor=None,
    noncentered=True,
):
    """Initialize the model from a dataset and a prior configuration.

    Parameters
    ----------
    data : pandas.DataFrame or str
        The dataset. A ``str`` is treated as a path and read with
        ``pd.read_csv(..., sep=None, engine="python")``, which sniffs the
        delimiter. Columns of ``object`` dtype are converted to ``category``
        on a copy, so a DataFrame passed by the caller is left unmodified.
        NOTE(review): the default ``None`` is not handled here and fails with
        AttributeError at ``select_dtypes``; callers presumably must always
        pass data — confirm.
    default_priors : optional
        Passed unchanged to ``PriorFactory``; the resulting factory is stored
        as ``self.default_priors``.
    auto_scale, dropna, taylor, noncentered : optional
        Not referenced by the statements in this block. NOTE(review):
        presumably stored or consumed elsewhere in the class — confirm.

    Warns
    -----
    UserWarning
        If any column name contains ``[`` or ``]``.
    """
    if isinstance(data, str):
        data = pd.read_csv(data, sep=None, engine="python")

    self.default_priors = PriorFactory(default_priors)

    obj_cols = data.select_dtypes(["object"]).columns
    if len(obj_cols) > 0:
        # Copy first: assigning into ``data`` would otherwise silently mutate
        # the DataFrame the caller passed in.
        data = data.copy()
        data[obj_cols] = data[obj_cols].apply(lambda x: x.astype("category"))
    self.data = data

    # Some random effects stuff later requires us to make guesses about
    # column groupings into terms based on patsy's naming scheme.
    # map(str, ...) keeps non-string column labels (e.g. integers) from making
    # str.join raise TypeError.
    if re.search(r"[\[\]]+", "".join(map(str, data.columns))):
        warnings.warn(
            "At least one of the column names in the specified "
            "dataset contain square brackets ('[' or ']'). "
            "This may cause unexpected behavior if you specify "
            "models with random effects. You are encouraged to "
            "rename your columns to avoid square brackets."
        )
    self.reset()