Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_code_serialization(e) -> None:
    """Round-trip ``e`` through ``serde.dump_code``/``serde.load_code``.

    The reconstructed object must compare equal to ``e`` according to
    ``check_equality``.
    """
    roundtripped = serde.load_code(serde.dump_code(e))
    assert check_equality(e, roundtripped)
def test_call_and_repr(func) -> None:
    """Check that ``func`` behaves identically after a code round trip.

    ``func`` is applied to the output of ``evaluate(BASE_RECIPE, ...)`` twice:
    once as given and once after ``load_code(dump_code(func))``. Both calls
    receive exactly the same arguments, so the only difference between them
    is the serialization round trip. The two results must agree within
    ``np.testing.assert_allclose`` tolerances.
    """
    global_state = {}
    x = evaluate(BASE_RECIPE, length=10, global_state=global_state)
    kwargs = dict(foo=42, bar=23)

    def _call(f):
        # Re-seed NumPy's global RNG before every call so any randomness
        # inside ``f`` draws the same values for both invocations.
        np.random.seed(0)
        return f(
            x,
            field_name="bar",
            length=10,
            # Each call gets its own copy so one call cannot leak state
            # into the other.
            global_state=global_state.copy(),
            **kwargs,
        )

    ret = _call(func)
    func_reconstructed = load_code(dump_code(func))
    ret2 = _call(func_reconstructed)
    np.testing.assert_allclose(ret2, ret)
def assert_serializable(x: transform.Transformation):
    """Assert that transformation ``x`` survives both JSON and code round trips.

    ``x`` is serialized and reloaded with each format. Each reloaded copy is
    then re-serialized with the same format, and the output must match the
    serialization of ``x``.

    Raises
    ------
    AssertionError
        If either round trip changes the serialized form. The message names
        the format that failed and the fully qualified class name of ``x``.
    """
    t = fqname_for(x.__class__)
    y = load_json(dump_json(x))
    z = load_code(dump_code(x))
    # Each message names the format actually being checked.
    assert dump_json(x) == dump_json(
        y
    ), f"JSON serialization for transformer {t} does not work"
    assert dump_code(x) == dump_code(
        z
    ), f"Code serialization for transformer {t} does not work"
lambda x: serde.load_code(serde.dump_code(x)),
],
)
for i in range(5)
]
x_dict = {
    i: Foo(
        b=random.uniform(0, B),
        a=str(random.randint(0, A)),
        c=Complex(
            x=str(random.uniform(0, C)), y=str(random.uniform(0, C))
        ),
    )
    for i in range(6)
}
# bar01 is the original object. bar02 is its code round trip, and bar03
# is bar02's JSON round trip.
bar01 = Bar(x_list, input_fields=fields, x_dict=x_dict)
bar02 = load_code(dump_code(bar01))
bar03 = load_json(dump_json(bar02))


def compare_tpes(x, y, z, tpe):
    """Assert that ``x``, ``y`` and ``z`` all have exactly the type ``tpe``."""
    assert tpe == type(x) == type(y) == type(z)


def compare_vals(x, y, z):
    """Assert that ``x``, ``y`` and ``z`` are all equal."""
    assert x == y == z


# Compare the original (bar01) against both round-tripped copies.
compare_tpes(bar01.x_list, bar02.x_list, bar03.x_list, tpe=list)
compare_tpes(bar01.x_dict, bar02.x_dict, bar03.x_dict, tpe=dict)
compare_tpes(
    bar01.input_fields, bar02.input_fields, bar03.input_fields, tpe=list
)
compare_vals(len(bar01.x_list), len(bar02.x_list), len(bar03.x_list))
compare_vals(len(bar01.x_dict), len(bar02.x_dict), len(bar03.x_dict))
def assert_serializable(x: transform.Transformation):
    """Assert that transformation ``x`` survives both JSON and code round trips.

    ``x`` is serialized and reloaded with each format. Each reloaded copy is
    then re-serialized with the same format, and the output must match the
    serialization of ``x``.

    Raises
    ------
    AssertionError
        If either round trip changes the serialized form. The message names
        the format that failed and the fully qualified class name of ``x``.
    """
    t = fqname_for(x.__class__)
    y = load_json(dump_json(x))
    z = load_code(dump_code(x))
    # Each message names the format actually being checked.
    assert dump_json(x) == dump_json(
        y
    ), f"JSON serialization for transformer {t} does not work"
    assert dump_code(x) == dump_code(
        z
    ), f"Code serialization for transformer {t} does not work"
def run_train_and_test(
    env: TrainEnv, forecaster_type: Type[Union[Estimator, Predictor]]
) -> None:
    """
    Build a forecaster from ``env``'s hyperparameters and persist a predictor.

    The forecaster is created with ``forecaster_type.from_hyperparameters``.
    If the result is already a ``Predictor``, it is used directly. Otherwise
    it is trained with ``run_train`` on ``env.datasets["train"]``, using
    ``env.datasets.get("validation")`` as validation data. The resulting
    predictor is serialized to ``env.path.model``.

    Parameters
    ----------
    env
        Training environment supplying ``hyperparameters``, ``datasets`` and
        ``path.model``.
    forecaster_type
        Estimator or predictor class. It must expose ``__version__`` and
        ``from_hyperparameters``.
    """
    # NOTE(review): despite the name, no test step runs in this body;
    # confirm whether evaluation on the "test" channel happens elsewhere.
    check_gpu_support()

    fq_name = fqname_for(forecaster_type)
    version = forecaster_type.__version__
    logger.info(f"Using gluonts v{gluonts.__version__}")
    logger.info(f"Using forecaster {fq_name} v{version}")

    forecaster = forecaster_type.from_hyperparameters(**env.hyperparameters)
    reconstruction = dump_code(forecaster)
    logger.info(
        "The forecaster can be reconstructed with the following expression: "
        f"{reconstruction}"
    )

    # Report only the known channels that were actually provided.
    available = [
        channel
        for channel in ("train", "validation", "test")
        if channel in env.datasets
    ]
    channels = ", ".join(available)
    logger.info(f"Using the following data channels: {channels}")

    predictor = (
        forecaster
        if isinstance(forecaster, Predictor)
        else run_train(
            forecaster, env.datasets["train"], env.datasets.get("validation")
        )
    )
    predictor.serialize(env.path.model)
def metric(metric: str, value: Any) -> None:
    """
    Log ``value`` as the reported value of ``metric``.

    The message has the form ``gluonts[<metric>]: <value>``. The value is
    rendered with ``dump_code``.

    Parameters
    ----------
    metric
        Name of the metric being reported.
    value
        Value to report for the metric.
    """
    rendered = dump_code(value)
    logger.info(f"gluonts[{metric}]: {rendered}")