How to use the bambi.priors.Prior class in bambi

To help you get started, we’ve selected a few bambi examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

github bambinos / bambi / tests / test_priors.py View on Github external
def test_prior_retrieval():
    """PriorFactory resolves priors by distribution, family, and term name."""
    config_file = join(dirname(__file__), 'data', 'sample_priors.json')
    pf = PriorFactory(config_file)

    # Lookup by distribution name.
    prior = pf.get(dist='asiago')
    assert prior.name == 'Asiago'
    assert isinstance(prior, Prior)
    assert prior.args['hardness'] == 10
    # 'holes' is not an argument of this distribution. Only the lookup can
    # raise, so no assertion is wrapped around it.
    with pytest.raises(KeyError):
        _ = prior.args['holes']

    # Lookup by family name: the family carries a link and a prior, and
    # nested prior references are resolved into Prior objects.
    family = pf.get(family='hard')
    assert isinstance(family, Family)
    assert family.link == 'grate'
    backup = family.prior.args['backup']
    assert isinstance(backup, Prior)
    assert backup.args['flavor'] == 10000

    # Lookup by term type.
    prior = pf.get(term='yellow')
    assert prior.name == 'Swiss'
github bambinos / bambi / tests / test_priors.py View on Github external
def test_prior_class():
    """Prior stores its name and keyword arguments; update() merges new ones."""
    prior = Prior('CheeseWhiz', holes=0, taste=-10)
    assert prior.name == 'CheeseWhiz'
    assert isinstance(prior.args, dict)
    assert prior.args['taste'] == -10

    prior.update(taste=-100, return_to_store=1)
    # New keys are added ...
    assert prior.args['return_to_store'] == 1
    # ... existing keys passed to update() are overwritten ...
    assert prior.args['taste'] == -100
    # ... and keys not mentioned in update() are left untouched.
    assert prior.args['holes'] == 0
github bambinos / bambi / bambi / backends / pymc.py View on Github external
def _expand_args(key, value, label):
            """Expand a Prior-valued distribution argument into its own distribution.

            Non-Prior values are returned unchanged. A Prior value is built
            via ``self._build_dist`` under the label ``"<label>_<key>"``.
            ``self`` and ``spec`` come from the enclosing scope, which is not
            visible in this snippet.
            """
            if not isinstance(value, Prior):
                return value
            sub_label = "%s_%s" % (label, key)
            return self._build_dist(spec, sub_label, value.name, **value.args)
github bambinos / bambi / bambi / models.py View on Github external
as the sole argument and returns one with the same shape.
        """
        if isinstance(family, str):
            family = self.default_priors.get(family=family)
        self.family = family

        # Override family's link if another is explicitly passed
        if link is not None:
            self.family.link = link

        if prior is None:
            prior = self.family.prior

        # implement default Uniform [0, sd(Y)] prior for residual SD
        if self.family.name == "gaussian":
            prior.update(sd=Prior("Uniform", lower=0, upper=self.clean_data[variable].std()))

        data = kwargs.pop("data", self.clean_data[variable])
        term = Term(variable, data, prior=prior, *args, **kwargs)
        self.y = term
        self.built = False
github bambinos / bambi / bambi / models.py View on Github external
def _prepare_prior(self, prior, _type):
        """Resolve the prior to attach to a model term.

        Parameters
        ----------
        prior : Prior object, or float, or None.
            A Prior instance is used as given, with auto-scaling disabled.
            Any other value is stored as the ``scale`` of the default prior
            for ``_type``.
        _type : string
            accepted values are: 'intercept', 'fixed', or 'random'.

        Returns
        -------
        Prior
            The prior to use for the term.
        """
        # Without auto-scaling, a missing prior falls back to the "<type>_flat"
        # default term prior.
        if prior is None and not self.auto_scale:
            prior = self.default_priors.get(term=_type + "_flat")

        if not isinstance(prior, Prior):
            requested_scale = prior
            prior = self.default_priors.get(term=_type)
            prior.scale = requested_scale
            # Only an explicit (non-None) scale switches auto-scaling off.
            if prior.scale is not None:
                prior._auto_scale = False  # pylint: disable=protected-access
            return prior

        # User-supplied Prior objects are never auto-scaled.
        prior._auto_scale = False  # pylint: disable=protected-access
        return prior
github bambinos / bambi / bambi / priors.py View on Github external
def _get_prior(self, spec, **kwargs):
        """Recursively turn a raw prior specification into a prior object.

        Parameters
        ----------
        spec : str, list, tuple, or any other value
            * ``str``: the name of an entry in ``self.dists``. A leading ``#``
              is optional and is stripped before the lookup. The entry is
              then resolved recursively.
            * ``list`` / ``tuple``: a ``(name, args)`` pair. If ``name``
              starts with ``#``, the referenced ``self.dists`` entry is used
              as the base prior. Otherwise a new ``Prior(name, **kwargs)`` is
              built. Every value in ``args`` is resolved recursively and
              applied to the base via ``update``.
            * anything else is returned unchanged (e.g. numeric arguments).
        **kwargs
            Extra arguments for ``Prior``. They are only used when building a
            literal (non-``#``) prior and are not forwarded to referenced or
            nested specs.

        Raises
        ------
        KeyError
            Presumably, when a referenced name is missing from ``self.dists``
            (assuming it is a dict -- confirm).
        """
        # Plain values (numbers, None, already-built objects) pass straight through.
        if not isinstance(spec, (str, list, tuple)):
            return spec

        if isinstance(spec, str):
            dist_key = re.sub(r"^\#", "", spec)
            return self._get_prior(self.dists[dist_key])

        dist_name, dist_args = spec
        if dist_name.startswith("#"):
            base = self._get_prior(self.dists[re.sub(r"^\#", "", dist_name)])
        else:
            base = Prior(dist_name, **kwargs)

        resolved = {}
        for arg_name, arg_spec in dist_args.items():
            resolved[arg_name] = self._get_prior(arg_spec)
        base.update(**resolved)
        return base
github bambinos / bambi / bambi / backends / stan.py View on Github external
if self.mu_cat:
            loops = "for (n in 1:N)\n\t\tyhat[n] = yhat[n] + %s" % " + ".join(self.mu_cat) + ";\n\t"
            self.expressions.append(loops)

        # Add expressions that go in transformed parameter block (they have
        # to come after variable definitions)
        self.transformed_parameters += self.expressions

        # add response variable (y)
        _response_format = self.families[spec.family.name]["format"]
        self.data.append("{} y{};".format(*_response_format))

        # add response distribution parameters other than the location
        # parameter
        for key, value in spec.family.prior.args.items():
            if key != spec.family.parent and isinstance(value, Prior):
                _bounds = _map_dist(value.name, **value.args)[1]
                _param = "real{} {}_{};".format(_bounds, spec.y.name, key)
                self.parameters.append(_param)

        # specify the response distribution
        _response_dist = self.families[spec.family.name]["name"]
        _response_args = "{}(yhat)".format(self.links[spec.family.link])
        _response_args = {spec.family.parent: _response_args}
        for key, value in spec.family.prior.args.items():
            if key != spec.family.parent:
                _response_args[key] = (
                    "{}_{}".format(spec.y.name, key) if isinstance(value, Prior) else str(value)
                )
        _dist = _map_dist(_response_dist, **_response_args)[0]
        self.model.append("y ~ {};".format(_dist))