Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def observed_data_to_xarray(self):
    """Build an ``xarray.Dataset`` containing every observed variable.

    Each entry of ``self.observations`` is flattened with ``utils.one_de`` and
    given dims/coords by ``generate_dims_coords``, using ``self.dims``
    (treated as an empty mapping when ``None``) and ``self.coords``. Each
    resulting ``DataArray`` carries only the coords for its own dims. The
    dataset attrs come from ``make_attrs(library=self.numpyro)``.
    """
    dims_by_var = {} if self.dims is None else self.dims
    data_vars = {}
    for var_name, raw_vals in self.observations.items():
        arr = utils.one_de(raw_vals)
        var_dims, all_coords = generate_dims_coords(
            arr.shape, var_name, dims=dims_by_var.get(var_name), coords=self.coords
        )
        # Restrict the coords to the dims this variable actually uses.
        own_coords = {
            dim: xr.IndexVariable((dim,), data=all_coords[dim]) for dim in var_dims
        }
        data_vars[var_name] = xr.DataArray(arr, dims=var_dims, coords=own_coords)
    return xr.Dataset(data_vars=data_vars, attrs=make_attrs(library=self.numpyro))
# NOTE(review): fragment only. The enclosing function header and the first
# branch condition (the `if` paired with the `elif` below) are not visible
# here, so this span is not runnable as shown.
# First branch: read the parameter count from the last axis of
# `sampler.chain` and the extra-argument count from `sampler.args`.
ndim = sampler.chain.shape[-1]
num_args = len(sampler.args)
elif hasattr(sampler, "log_prob_fn"):
# Sampler exposes `get_chain()` plus a `log_prob_fn` carrying the extra
# arguments; read both counts from those.
ndim = sampler.get_chain().shape[-1]
num_args = len(sampler.log_prob_fn.args)
else:
# No extra-argument information available; only the chain is inspected.
ndim = sampler.get_chain().shape[-1]
num_args = 0 # emcee only stores the posterior samples
# `slices` selects which chain columns make up each variable. With no
# slices, each index of `utils.arange(ndim)` is used as its own slice.
if slices is None:
slices = utils.arange(ndim)
num_vars = ndim
else:
num_vars = len(slices)
# Validate user-supplied slices by expanding each one into the column
# indices it selects and concatenating them.
indexs = utils.arange(ndim)
slicing_try = np.concatenate([utils.one_de(indexs[idx]) for idx in slices])
# Warn when the number of distinct selected columns differs from `ndim`.
if len(set(slicing_try)) != ndim:
# NOTE(review): the message reports len(slicing_try), which counts an index
# once per slice that selects it, so overlapping slices inflate the
# "captured" figure.
warnings.warn(
"Check slices: Not all parameters in chain captured. "
"{} are present, and {} have been captured.".format(ndim, len(slicing_try)),
SyntaxWarning,
)
# Warn when some column index is selected by more than one slice.
if len(slicing_try) != len(set(slicing_try)):
warnings.warn(
"Overlapping slices. Check the index present: {}".format(slicing_try), SyntaxWarning
)
# Default names when none are given: var_0, var_1, ... and arg_0, arg_1, ...
if var_names is None:
var_names = ["var_{}".format(idx) for idx in range(num_vars)]
if arg_names is None:
arg_names = ["arg_{}".format(idx) for idx in range(num_args)]
def log_likelihood_vals_point(point):
    """Return one flat array of per-observation log-likelihood values at ``point``.

    ``cached`` is not a parameter; it is resolved from the enclosing scope
    and iterated as ``(observed_variable, log_likelihood_callable)`` pairs.
    Each callable's result at ``point`` is flattened with ``utils.one_de``.
    For variables whose ``missing_values`` is truthy, entries where
    ``observations.mask`` is set are dropped. The pieces are concatenated in
    ``cached`` order.
    """

    def _per_variable():
        for obs_var, loglik_fn in cached:
            flat = utils.one_de(loglik_fn(point))
            # Keep only entries that are not masked as missing.
            yield flat[~obs_var.observations.mask] if obs_var.missing_values else flat

    return np.concatenate(list(_per_variable()))
def observed_data_to_xarray(self):
    """Convert observed data to an ``xarray.Dataset``.

    Reads ``self.observed_data`` with ``_read_data`` and keeps only the
    variables named by ``self.observed_data_var``. That attribute may be a
    single name, a collection of names, or ``None`` to keep every variable.
    Each kept value is flattened with ``utils.one_de`` and given dims/coords
    by ``generate_dims_coords`` using ``self.dims`` and ``self.coords``.

    Returns
    -------
    xarray.Dataset
        One data variable per kept observed variable. Each variable carries
        only the coordinates for its own dims.
    """
    observed_data_raw = _read_data(self.observed_data)
    variables = self.observed_data_var
    if isinstance(variables, str):
        variables = [variables]
    # Treat a missing dims mapping as "no user-specified dims" (as the other
    # converters do) instead of failing on ``None.get``.
    dims = {} if self.dims is None else self.dims
    observed_data = {}
    for key, vals in observed_data_raw.items():
        if variables is not None and key not in variables:
            continue
        vals = utils.one_de(vals)
        val_dims, coords = generate_dims_coords(
            vals.shape, key, dims=dims.get(key), coords=self.coords
        )
        # Filter coords down to this variable's dims. xarray requires each
        # coordinate's dims to be a subset of the DataArray's dims, so coords
        # for other variables' dims must not be passed through.
        coords = {dim: xr.IndexVariable((dim,), data=coords[dim]) for dim in val_dims}
        observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
    return xr.Dataset(data_vars=observed_data)
def observed_data_to_xarray(self):
    """Collect the observed variables into an ``xarray.Dataset``.

    Values in ``self.observations`` that expose ``get_value`` are unwrapped
    through that method first. Each value is then flattened with
    ``utils.one_de`` and given dims/coords by ``generate_dims_coords``, using
    ``self.dims`` (an empty mapping when ``None``) and ``self.coords``. Each
    ``DataArray`` keeps only the coords for its own dims. Dataset attrs come
    from ``make_attrs(library=self.pymc3)``.
    """
    var_dims_map = self.dims if self.dims is not None else {}

    def _to_data_array(obs_name, obs_vals):
        # Unwrap objects that hold their data behind get_value().
        if hasattr(obs_vals, "get_value"):
            obs_vals = obs_vals.get_value()
        obs_vals = utils.one_de(obs_vals)
        dims_out, coords_all = generate_dims_coords(
            obs_vals.shape, obs_name, dims=var_dims_map.get(obs_name), coords=self.coords
        )
        # Filter the coords down to the dims this variable actually uses.
        coords_out = {d: xr.IndexVariable((d,), data=coords_all[d]) for d in dims_out}
        return xr.DataArray(obs_vals, dims=dims_out, coords=coords_out)

    observed = {
        name: _to_data_array(name, vals) for name, vals in self.observations.items()
    }
    return xr.Dataset(data_vars=observed, attrs=make_attrs(library=self.pymc3))