Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _write_endpoint_config_to_yaml(path: Path, data: Dict[Text, Any]) -> Path:
    """Dump ``data`` as YAML to ``endpoints.yml`` inside ``path``.

    Args:
        path: Directory the ``endpoints.yml`` file is placed in.
        data: Endpoint configuration to serialise.

    Returns:
        The path of the written ``endpoints.yml`` file.
    """
    target_file = path / "endpoints.yml"
    io_utils.write_yaml_file(data, target_file)
    return target_file
@pytest.fixture(scope="module")
def loop():
    """Provide a module-scoped asyncio event loop.

    A new loop is created, installed as the current event loop, passed through
    ``rasa.utils.io.enable_async_loop_debugging`` and yielded to the tests.
    The loop is closed once the fixture is torn down.
    """
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    # The debugging helper returns the loop that should be used from here on.
    event_loop = rasa.utils.io.enable_async_loop_debugging(event_loop)
    yield event_loop
    event_loop.close()
async def test_dump_and_restore_as_json(default_agent, tmpdir_factory):
    """Check that a tracker survives a JSON dump / restore round trip.

    Every tracker loaded from ``DEFAULT_STORIES_FILE`` has its
    ``AFTER_RESTART`` state written to a JSON file in a fresh temporary
    directory, is loaded back with the agent's domain, and must compare equal
    to the tracker it was dumped from.
    """
    story_trackers = await default_agent.load_data(DEFAULT_STORIES_FILE)

    for tracker in story_trackers:
        dump_file = tmpdir_factory.mktemp("tracker").join("dumped_tracker.json")
        dump_location = dump_file.strpath

        state = tracker.current_state(EventVerbosity.AFTER_RESTART)
        rasa.utils.io.dump_obj_as_json_to_file(dump_location, state)

        reloaded = restore.load_tracker_from_json(dump_location, default_agent.domain)

        assert reloaded == tracker
# NOTE(review): the enclosing test function header is not visible in this
# excerpt; these statements are presumably the body of a CLI training test.

# Delete "domain.yml" (path relative to the current working directory) before
# the "train" command is invoked.
os.remove("domain.yml")
# Invoke "train" with config.yml and the data directory, writing the result to
# "train_models_no_domain" under the fixed model name "nlu-model-only".
run_in_default_project(
"train",
"-c",
"config.yml",
"--data",
"data",
"--out",
"train_models_no_domain",
"--fixed-model-name",
"nlu-model-only",
)
# The output directory must exist and contain exactly one file.
assert os.path.exists("train_models_no_domain")
files = io_utils.list_files("train_models_no_domain")
assert len(files) == 1
# Unpack the archive and check that it contains nlu/metadata.json.
trained_model_path = "train_models_no_domain/nlu-model-only.tar.gz"
unpacked = model.unpack_model(trained_model_path)
metadata_path = os.path.join(unpacked, "nlu", "metadata.json")
assert os.path.exists(metadata_path)
def test_get_valid_config(parameters):
    """Exercise ``_get_valid_config`` for one parametrised scenario.

    ``parameters`` is read for the keys ``config_data``, ``default_config``,
    ``mandatory_keys`` and ``error``. Config contents that are not ``None``
    are written to YAML files in fresh temporary directories. When ``error``
    is truthy the call must raise ``SystemExit``; otherwise the returned
    config file must contain every mandatory key.
    """
    import rasa.utils.io

    def _write_to_fresh_tmp_dir(content, file_name):
        # Each file gets its own temporary directory.
        target = os.path.join(tempfile.mkdtemp(), file_name)
        rasa.utils.io.write_yaml_file(content, target)
        return target

    config_path = None
    user_config = parameters["config_data"]
    if user_config is not None:
        config_path = _write_to_fresh_tmp_dir(user_config, "config.yml")

    default_config_path = None
    fallback_config = parameters["default_config"]
    if fallback_config is not None:
        default_config_path = _write_to_fresh_tmp_dir(
            fallback_config, "default-config.yml"
        )

    if parameters["error"]:
        # Note: the default config path is not passed in the error scenario.
        with pytest.raises(SystemExit):
            _get_valid_config(config_path, parameters["mandatory_keys"])
        return

    mandatory_keys = parameters["mandatory_keys"]
    resolved_path = _get_valid_config(config_path, mandatory_keys, default_config_path)
    loaded_config = rasa.utils.io.read_yaml_file(resolved_path)

    for key in mandatory_keys:
        assert key in loaded_config
SessionStarted(),
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(
f"/greet{json.dumps(slot_1)}",
{"name": "greet", "confidence": 1.0},
[{"entity": entity, "start": 6, "end": 22, "value": "Core"}],
),
SlotSet(entity, slot_1[entity]),
ActionExecuted("utter_greet"),
BotUttered("hey there Core!"),
ActionExecuted(ACTION_LISTEN_NAME),
ActionExecuted(ACTION_SESSION_START_NAME),
SessionStarted(),
# the initial SlotSet is reapplied after the SessionStarted sequence
SlotSet(entity, slot_1[entity]),
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(
f"/greet{json.dumps(slot_2)}",
{"name": "greet", "confidence": 1.0},
[
{
"entity": entity,
"start": 6,
"end": 42,
"value": "post-session start hello",
}
],
),
SlotSet(entity, slot_2[entity]),
ActionExecuted(ACTION_LISTEN_NAME),
]
def test_common_action_prefix():
    """Check ``_length_of_common_action_prefix`` on two diverging sequences.

    Both sequences start with the same four events and differ from the
    ``SlotSet`` onwards; the helper is expected to report a common prefix
    length of 3.
    """

    def _shared_start():
        # Build new event objects on every call so the two sequences do not
        # share any instances.
        return [
            ActionExecuted("action_listen"),
            ActionExecuted("greet"),
            UserUttered("hey"),
            ActionExecuted("amazing"),
        ]

    # From here on the two sequences differ.
    this = _shared_start() + [
        SlotSet("my_slot", "a"),
        ActionExecuted("a"),
        ActionExecuted("after_a"),
    ]
    other = _shared_start() + [
        SlotSet("my_slot", "b"),
        ActionExecuted("b"),
        ActionExecuted("after_b"),
    ]

    common_length = visualization._length_of_common_action_prefix(this, other)

    assert common_length == 3
async def test_action_session_start_without_slots(
    default_channel: CollectingOutputChannel,
    template_nlg: TemplatedNaturalLanguageGenerator,
    template_sender_tracker: DialogueStateTracker,
    default_domain: Domain,
):
    """Check the events produced by running ``ActionSessionStart``.

    With the fixtures supplied here, the action must return exactly a
    ``SessionStarted`` event followed by an ``ActionExecuted`` for the listen
    action.
    """
    session_start = ActionSessionStart()
    produced_events = await session_start.run(
        default_channel, template_nlg, template_sender_tracker, default_domain
    )

    expected_events = [SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME)]
    assert produced_events == expected_events
async def test_successful_rephrasing(
    self, default_channel, default_nlg, default_domain
):
    """Check the tracker state after a rephrase request is answered.

    The conversation is: a low-confidence "greet", an affirmation request that
    is denied, a rephrase request, and finally a confident "bye". After
    ``_get_tracker_after_reverts`` has processed these events, the latest
    intent must be "bye" and the exported story must contain only that
    message.
    """
    conversation = [
        ActionExecuted(ACTION_LISTEN_NAME),
        user_uttered("greet", 0.2),
        ActionExecuted(ACTION_DEFAULT_ASK_AFFIRMATION_NAME),
        ActionExecuted(ACTION_LISTEN_NAME),
        user_uttered("deny", 1),
        ActionExecuted(ACTION_DEFAULT_ASK_REPHRASE_NAME),
        ActionExecuted(ACTION_LISTEN_NAME),
        user_uttered("bye", 1),
    ]

    reverted_tracker = await self._get_tracker_after_reverts(
        conversation, default_channel, default_nlg, default_domain
    )

    latest_intent = reverted_tracker.latest_message.parse_data["intent"]["name"]
    assert "bye" == latest_intent
    assert reverted_tracker.export_stories() == "## sender\n* bye\n"
def test_events_metadata():
    """Check that event metadata is kept in the tracker's dict representation.

    It should be possible to attach arbitrary metadata to any event and then
    retrieve it after getting the tracker dict representation.
    """
    tracker = get_tracker(
        [
            ActionExecuted("one", metadata={"one": 1}),
            user_uttered("two", 1, metadata={"two": 2}),
            ActionExecuted(ACTION_LISTEN_NAME, metadata={"three": 3}),
        ]
    )

    serialised_events = tracker.current_state(EventVerbosity.ALL)["events"]

    # Index explicitly so a too-short event list still fails with IndexError.
    expected_metadata = [{"one": 1}, {"two": 2}, {"three": 3}]
    for position, expected in enumerate(expected_metadata):
        assert serialised_events[position]["metadata"] == expected