Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test___init__(self, load_primitive_mock, import_object_mock, set_hps_mock):
    """Check that MLBlock keeps the loaded metadata, its name and the imported primitive."""

    def build_metadata():
        # Return a brand-new dict on every call so that the mock's return
        # value and the expected value below are never the same object.
        produce_spec = {
            'args': [{'name': 'argument'}],
            'output': [],
        }
        return {
            'name': 'a_primitive_name',
            'primitive': 'a_primitive_name',
            'produce': produce_spec,
        }

    # NOTE(review): set_hps_mock is presumably injected by a @patch decorator
    # outside this chunk; it is accepted but not asserted on here.
    load_primitive_mock.return_value = build_metadata()

    block = MLBlock('a_primitive_name', argument='value')

    expected_metadata = build_metadata()
    assert block.metadata == expected_metadata
    assert block.name == 'a_primitive_name'
    assert block.primitive == import_object_mock.return_value
# NOTE(review): this fragment is the body of a test whose `def` line (and the
# decorators that would supply `lp_mock` and `io_mock`) is not part of this
# chunk. Per the final assert, it checks that MLBlock.produce() leaves the
# caller's `hyperparameters` dict unmodified.
# Stand-in primitive that mutates the list it receives, so any aliasing between
# the block's stored hyperparameters and what the primitive gets is observable.
def primitive(a_list_param):
a_list_param.append('b')
# Make the patched import return the mutating primitive defined above.
io_mock.return_value = primitive
# Minimal primitive metadata: no produce args and no outputs.
lp_mock.return_value = {
'name': 'a_primitive',
'primitive': 'a_primitive',
'produce': {
'args': [],
'output': []
}
}
mlblock = MLBlock('a_primitive')
hyperparameters = {
'a_list_param': ['a']
}
# Inject the hyperparameters directly on the private attribute.
mlblock._hyperparameters = hyperparameters
# NOTE(review): presumably produce() forwards the stored hyperparameters to the
# primitive as keyword arguments — confirm against MLBlock.produce.
mlblock.produce()
# If produce() passed the original list through, the primitive's append('b')
# would show up here; the test requires that it does not.
assert 'b' not in hyperparameters['a_list_param']
def test_get_tunable_hyperparameters(self, load_primitive_mock, import_object_mock):
    """get_tunable_hyperparameters has to return a copy of the _tunable attribute.

    The stored dict is deliberately non-empty: with an empty dict, the equality
    check would pass for any method that merely returns a fresh ``{}``.
    """
    load_primitive_mock.return_value = {
        'name': 'a_primitive_name',
        'primitive': 'a_primitive_name',
        'produce': {
            'args': [],
            'output': []
        }
    }
    mlblock = MLBlock('given_primitive_name')
    tunable = {
        'a_tunable_hyperparameter': {
            'type': 'int',
            'default': 1,
        }
    }
    mlblock._tunable = tunable

    returned = mlblock.get_tunable_hyperparameters()

    # Same content ...
    assert returned == tunable
    # ... but not the very same object, so callers cannot mutate _tunable.
    assert returned is not tunable
def test___str__(self, load_primitive_mock, import_object_mock):
    """str() of the MLBlock built for 'a_primitive_name' is 'MLBlock - a_primitive_name'."""
    produce_spec = {'args': [], 'output': []}
    load_primitive_mock.return_value = {
        'name': 'a_primitive_name',
        'primitive': 'a_primitive_name',
        'produce': produce_spec,
    }

    rendered = str(MLBlock('a_primitive_name'))

    assert rendered == 'MLBlock - a_primitive_name'
# NOTE(review): fragment — the enclosing method's `def` line (which binds
# `self` and `instance`) is not part of this chunk. It is presumably the body
# of `replace_instance_methods(self, instance)`, which build_mlblock calls —
# confirm.
# Read the method names and the model builder once; the closures below close
# over these locals instead of consulting self.metadata on each call.
# NOTE(review): self.metadata['fit'] raises KeyError if the key is absent,
# while the guard inside fit() only handles a falsy value; consider
# self.metadata.get('fit') if 'fit' may be omitted.
fit_method_name = self.metadata['fit']
produce_method_name = self.metadata['produce']
build_method = self.build_mlblock_model
def fit(self, *args, **kwargs):
# Only fit if fit method provided.
if fit_method_name:
getattr(self.model, fit_method_name)(*args, **kwargs)
# Bind the plain function to `instance` via the descriptor protocol, so
# `self` inside fit() is `instance`.
instance.fit = fit.__get__(instance, MLBlock)
def produce(self, *args, **kwargs):
# Every MLBlock needs a produce method.
# Delegates to the model method named by metadata['produce'] and returns its result.
return getattr(self.model, produce_method_name)(*args, **kwargs)
instance.produce = produce.__get__(instance, MLBlock)
def update_model(self, fixed_hyperparams, tunable_hyperparams):
# Rebuild the model with the outer object's builder (captured above as
# build_method) and store it on the instance.
self.model = build_method(fixed_hyperparams, tunable_hyperparams)
instance.update_model = update_model.__get__(instance, MLBlock)
# NOTE(review): fragment — this starts inside an `if`/`else` whose `if` line
# and enclosing loop/def are not part of this chunk. Presumably the `if`
# branch handles `primitive` given as a plain name and the `else` branch a
# dict carrying a 'name' key — confirm.
primitive_name = primitive
else:
primitive_name = primitive['name']
try:
# Count occurrences of each primitive so repeats get distinct block names:
# '<primitive>#1', '<primitive>#2', ... (block_names_count is presumably a
# collections.Counter — verify).
block_names_count.update([primitive_name])
block_count = block_names_count[primitive_name]
block_name = '{}#{}'.format(primitive_name, block_count)
# Prefer init params keyed by the numbered block name; fall back to params
# keyed by the bare primitive name.
block_params = self.init_params.get(block_name, dict())
if not block_params:
block_params = self.init_params.get(primitive_name, dict())
# Bare-name params are shared by every block of that primitive, so warn
# once the same params are applied to a second or later block.
if block_params and block_count > 1:
LOGGER.warning(("Non-numbered init_params are being used "
"for more than one block %s."), primitive_name)
block = MLBlock(primitive, **block_params)
blocks[block_name] = block
except Exception:
# Log with traceback and the offending primitive, then propagate unchanged.
LOGGER.exception("Exception caught building MLBlock %s", primitive)
raise
return blocks
Args:
blocks: A list of MLBlocks composing this pipeline. MLBlocks
can be either MLBlock instances or primitive names to
load from the configuration JSON files. If falsy,
``self.BLOCKS`` is used instead.
init_params: Initialization parameters for the blocks. They are
normalized through ``self.get_nested`` before use; the
expected shape is not visible here (TODO confirm).

Raises:
ValueError: If neither ``blocks`` nor ``self.BLOCKS`` provides
any block.
"""
# NOTE(review): fragment — the `def` line and the opening of this docstring
# are not part of this chunk.
blocks = blocks or self.BLOCKS
init_params = self.get_nested(init_params)
if not blocks:
raise ValueError("At least one block is needed")
# OrderedDict keeps blocks in the order given, keyed by block name.
# NOTE(review): two blocks with the same name overwrite each other here.
self.blocks = OrderedDict()
for block in blocks:
# Anything that is not already an MLBlock gets built via _load_block.
if not isinstance(block, MLBlock):
block = self._load_block(block, init_params)
self.blocks[block.name] = block
def build_mlblock(self):
    """Build an MLBlock instance from this object's metadata.

    The fixed hyperparameters are the metadata's ``fixed_hyperparameters``
    overridden by ``self.init_params``. The tunable hyperparameters come from
    ``self.get_mlhyperparams`` for the block name. The model is built with
    ``self.build_mlblock_model``, and the new instance's methods are then
    replaced via ``self.replace_instance_methods``.

    Returns:
        MLBlock: The newly built block.
    """
    block_name = self.metadata['name']
    # Work on a copy: updating the metadata dict in place would leak
    # init_params into self.metadata and into every later build.
    fixed_hyperparams = dict(self.metadata['fixed_hyperparameters'])
    fixed_hyperparams.update(self.init_params)
    tunable_hyperparams = self.get_mlhyperparams(block_name)
    model = self.build_mlblock_model(fixed_hyperparams, tunable_hyperparams)
    instance = MLBlock(
        name=block_name,
        model=model,
        fixed_hyperparams=fixed_hyperparams,
        tunable_hyperparams=tunable_hyperparams,
    )
    self.replace_instance_methods(instance)
    return instance