Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_policy_priority():
    """Check that the higher-priority policy is chosen whatever its position.

    Two ``ConstantPolicy`` instances (priority 1 and priority 2) are combined
    into two ensembles in opposite orders. For each ensemble the test expects
    ``probabilities_using_best_policy`` to:

    * report the priority-2 policy as the best policy, named
      ``policy_<index>_<ClassName>`` with the index it has in that ensemble;
    * return the probabilities that the priority-2 policy predicts on its own.
    """
    domain = Domain.load("data/test_domains/default.yml")
    tracker = DialogueStateTracker.from_events("test", [UserUttered("hi")], [])

    priority_1 = ConstantPolicy(priority=1, predict_index=0)
    priority_2 = ConstantPolicy(priority=2, predict_index=1)

    policy_ensemble_0 = SimplePolicyEnsemble([priority_1, priority_2])
    policy_ensemble_1 = SimplePolicyEnsemble([priority_2, priority_1])

    priority_2_result = priority_2.predict_action_probabilities(tracker, domain)

    # Each pair is (ensemble, index of priority_2 inside that ensemble). Checking
    # both orders ensures the winner is decided by priority, not list position.
    for ensemble, i in ((policy_ensemble_0, 1), (policy_ensemble_1, 0)):
        result, best_policy = ensemble.probabilities_using_best_policy(
            tracker, domain
        )
        assert best_policy == "policy_{}_{}".format(i, type(priority_2).__name__)
        assert result == priority_2_result
async def test_update_tracker_session_with_slots(
    default_channel: CollectingOutputChannel,
    default_processor: MessageProcessor,
    monkeypatch: MonkeyPatch,
):
    """Check that an expired-session update keeps the previously applied events.

    A user utterance and five ``SlotSet`` events are applied to a fresh tracker.
    ``_has_session_expired`` is patched so that ``_update_tracker_session``
    takes its expired-session path. The tracker is then saved and reloaded from
    the tracker store. The reloaded events must still contain the utterance and
    every slot event.
    """
    sender_id = uuid.uuid4().hex
    tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

    # apply a user uttered and five slots
    user_event = UserUttered("some utterance")
    tracker.update(user_event)

    slot_set_events = [SlotSet(f"slot key {i}", f"test value {i}") for i in range(5)]
    for event in slot_set_events:
        tracker.update(event)

    # patch `_has_session_expired()` so the `_update_tracker_session()` call actually
    # does something
    monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True)

    await default_processor._update_tracker_session(tracker, default_channel)

    # the save is not called in _update_tracker_session()
    default_processor._save_tracker(tracker)

    # Reload the persisted tracker. Without these checks the test would pass
    # no matter what the session update did to the events.
    events = list(default_processor.tracker_store.retrieve(sender_id).events)
    assert user_event in events
    for slot_event in slot_set_events:
        assert slot_event in events
@pytest.mark.parametrize(
"one_event",
[
UserUttered("/greet", {"name": "greet", "confidence": 1.0}, []),
UserUttered(metadata={"type": "text"}),
UserUttered(metadata=None),
UserUttered(text="hi", message_id="1", metadata={"type": "text"}),
SlotSet("name", "rasa"),
Restarted(),
AllSlotsReset(),
ConversationPaused(),
ConversationResumed(),
StoryExported(),
ActionReverted(),
UserUtteranceReverted(),
ActionExecuted("my_action"),
ActionExecuted("my_action", "policy_1_KerasPolicy", 0.8),
FollowupAction("my_action"),
BotUttered("my_text", {"my_data": 1}),
AgentUttered("my_text", "my_data"),
ReminderScheduled("my_action", datetime.now()),
ReminderScheduled("my_action", datetime.now(pytz.timezone("US/Central"))),
],
tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)
tracker.update(UserUttered("test"))
tracker.update(ActionExecuted("action_schedule_reminder"))
tracker.update(reminder)
default_processor.tracker_store.save(tracker)
await default_processor.handle_reminder(
reminder, sender_id, default_channel, default_processor.nlg
)
# retrieve the updated tracker
t = default_processor.tracker_store.retrieve(sender_id)
assert t.events[-5] == UserUttered("test")
assert t.events[-4] == ActionExecuted("action_schedule_reminder")
assert isinstance(t.events[-3], ReminderScheduled)
assert t.events[-2] == UserUttered(
f"{EXTERNAL_MESSAGE_PREFIX}remind", intent={"name": "remind", IS_EXTERNAL: True}
)
assert t.events[-1] == ActionExecuted("action_listen")
UserUttered("/greet", {"name": "greet", "confidence": 1.0}, []),
UserUttered("/goodbye", {"name": "goodbye", "confidence": 1.0}, []),
),
(SlotSet("my_slot", "value"), SlotSet("my__other_slot", "value")),
(Restarted(), None),
(AllSlotsReset(), None),
(ConversationPaused(), None),
(ConversationResumed(), None),
(StoryExported(), None),
(ActionReverted(), None),
(UserUtteranceReverted(), None),
(ActionExecuted("my_action"), ActionExecuted("my_other_action")),
(FollowupAction("my_action"), FollowupAction("my_other_action")),
(
BotUttered("my_text", {"my_data": 1}),
BotUttered("my_other_test", {"my_other_data": 1}),
),
def _reset(self) -> None:
    """Return the tracker's per-conversation state to its initial values.

    The stored events are not touched. The method resets:

    * the slots, via ``_reset_slots``;
    * the paused flag, to ``False``;
    * the latest action name, to ``None``;
    * the latest user message and bot utterance, to empty placeholder events;
    * the follow-up action, to ``ACTION_LISTEN_NAME``;
    * the active form, to an empty dict.
    """
    self._reset_slots()
    # Tuple targets are assigned left to right, so the attribute write order
    # is the same as one assignment per line.
    self._paused, self.latest_action_name = False, None
    self.latest_message, self.latest_bot_utterance = (
        UserUttered.empty(),
        BotUttered.empty(),
    )
    self.followup_action, self.active_form = ACTION_LISTEN_NAME, {}
def explicit_events(
    self, domain: Domain, should_append_final_listen: bool = True
) -> List[Event]:
    """Return the story step's events with the implicit ones made explicit.

    The story DSL does not spell out every event: listen actions and slots
    set from entities are left implicit. This method builds a new list with
    those events added alongside the step's own events:

    * before each ``UserUttered``, ``_add_action_listen`` is called on the list;
    * after each ``UserUttered``, the events returned by
      ``domain.slots_for_entities`` for its entities are appended;
    * if the step has no end checkpoints and ``should_append_final_listen`` is
      true, ``_add_action_listen`` is called once more at the end.
    """
    explicit = []
    for event in self.events:
        is_user_turn = isinstance(event, UserUttered)
        if is_user_turn:
            self._add_action_listen(explicit)
        explicit.append(event)
        if is_user_turn:
            explicit.extend(domain.slots_for_entities(event.entities))

    if not self.end_checkpoints and should_append_final_listen:
        self._add_action_listen(explicit)
    return explicit
if (
reminder_event.kill_on_user_message
and self._has_message_after_reminder(tracker, reminder_event)
or not self._is_reminder_still_valid(tracker, reminder_event)
):
logger.debug(
"Canceled reminder because it is outdated. "
"(event: {} id: {})".format(
reminder_event.action_name, reminder_event.name
)
)
else:
# necessary for proper featurization, otherwise the previous
# unrelated message would influence featurization
tracker.update(UserUttered.empty())
action = self._get_action(reminder_event.action_name)
should_continue = await self._run_action(
action, tracker, output_channel, nlg
)
if should_continue:
user_msg = UserMessage(None, output_channel, sender_id)
await self._predict_and_execute_next_action(user_msg, tracker)
# save tracker state to continue conversation from this state
self._save_tracker(tracker)
async def _handle_message_with_tracker(
self, message: UserMessage, tracker: DialogueStateTracker
) -> None:
if message.parse_data:
parse_data = message.parse_data
else:
parse_data = await self._parse_message(message, tracker)
# don't ever directly mutate the tracker
# - instead pass its events to log
tracker.update(
UserUttered(
message.text,
parse_data["intent"],
parse_data["entities"],
parse_data,
input_channel=message.input_channel,
message_id=message.message_id,
metadata=message.metadata,
),
self.domain,
)
if parse_data["entities"]:
self._log_slots(tracker)
logger.debug(
"Logged UserUtterance - "
def explicit_events(
    self, domain: Domain, should_append_final_listen: bool = True
) -> List[Event]:
    """Return the story step's events including the implicit ones.

    Listen actions and slots set from entities are not always written out
    in the story DSL. The returned list adds them next to the step's own
    events:

    * each ``UserUttered`` is preceded by a ``_add_action_listen`` call and
      followed by the events from ``domain.slots_for_entities``;
    * every other event is copied through unchanged;
    * a trailing ``_add_action_listen`` call is made when the step has no end
      checkpoints and ``should_append_final_listen`` is true.
    """
    collected = []
    for event in self.events:
        if not isinstance(event, UserUttered):
            collected.append(event)
            continue
        self._add_action_listen(collected)
        collected.append(event)
        collected.extend(domain.slots_for_entities(event.entities))

    if not self.end_checkpoints and should_append_final_listen:
        self._add_action_listen(collected)
    return collected