Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
async def test_training_script_without_max_history_set(tmpdir):
    """Train with ``no_max_hist_config.yml`` and check featurizer history.

    After training into ``tmpdir`` and reloading the agent, every policy
    whose featurizer exposes ``max_history`` is inspected: a ``FormPolicy``
    must report ``2``, any other policy must report its featurizer's
    ``MAX_HISTORY_DEFAULT``.
    """
    model_dir = tmpdir.strpath

    await train(
        DEFAULT_DOMAIN_PATH_WITH_SLOTS,
        DEFAULT_STORIES_FILE,
        model_dir,
        interpreter=RegexInterpreter(),
        policy_config="data/test_config/no_max_hist_config.yml",
        kwargs={},
    )

    trained_agent = Agent.load(model_dir)

    for current_policy in trained_agent.policy_ensemble.policies:
        featurizer = current_policy.featurizer

        # Featurizers without a history window are not part of this check.
        if not hasattr(featurizer, "max_history"):
            continue

        # Exact type match on purpose: only FormPolicy itself expects 2.
        if type(current_policy) == FormPolicy:
            expected_history = 2
        else:
            expected_history = featurizer.MAX_HISTORY_DEFAULT

        assert featurizer.max_history == expected_history
def test_two_stage_fallback_without_deny_suggestion(domain, policy_config):
    """Building an ``Agent`` from the given domain/policies must be rejected.

    Constructing the agent is expected to raise ``InvalidDomain`` whose
    message mentions that the ``out_of_scope`` intent is required.
    """
    expected_fragment = "The intent 'out_of_scope' must be present"

    # Domain and ensemble are built inside the context manager so that an
    # InvalidDomain raised by either factory is captured as well.
    with pytest.raises(InvalidDomain) as raised:
        Agent(
            domain=Domain.from_dict(domain),
            policies=PolicyEnsemble.from_dict(policy_config),
        )

    error_text = str(raised.value)
    assert expected_fragment in error_text
async def test_training_data_is_reproducible(tmpdir, default_domain):
    """Loading the moodbot stories twice must yield the same dialogues.

    Both loads go through the same agent; each item of the first load is
    compared, position by position, with the item at the same index of the
    second load using the string form of ``as_dialogue()``.
    """
    stories_path = "examples/moodbot/data/stories.md"
    moodbot_agent = Agent(
        "examples/moodbot/domain.yml", policies=[AugmentedMemoizationPolicy()]
    )

    first_load = await moodbot_agent.load_data(stories_path)
    # Second, independent load of the very same file.
    second_load = await moodbot_agent.load_data(stories_path)

    # Compare content and ordering: item i of one load must match item i of
    # the other.
    for position, sample in enumerate(first_load):
        first_dialogue = str(sample.as_dialogue())
        second_dialogue = str(second_load[position].as_dialogue())
        assert first_dialogue == second_dialogue
async def test_agent_train(tmpdir, default_domain):
    """Train, persist and reload a moodbot agent; the domain must survive.

    The reloaded agent's domain is compared with the original agent's domain
    on action names, intents, entities, templates and slot names.
    """
    stories_path = "examples/moodbot/data/stories.md"
    original = Agent(
        "examples/moodbot/domain.yml", policies=[AugmentedMemoizationPolicy()]
    )

    stories = await original.load_data(stories_path)
    original.train(stories)
    original.persist(tmpdir.strpath)

    restored = Agent.load(tmpdir.strpath)

    # Domain attributes that must be equal after the persist/load round trip.
    for attribute in ("action_names", "intents", "entities", "templates"):
        restored_value = getattr(restored.domain, attribute)
        original_value = getattr(original.domain, attribute)
        assert restored_value == original_value

    # Slots are compared by name only.
    restored_slot_names = [slot.name for slot in restored.domain.slots]
    original_slot_names = [slot.name for slot in original.domain.slots]
    assert restored_slot_names == original_slot_names
async def test_agent_train(tmpdir, default_domain):
    """Check that a trained agent's domain is unchanged by persist + load.

    An agent is trained on the moodbot stories, written to ``tmpdir`` and
    loaded back; action names, intents, entities, templates and the names of
    the slots must match between the two agents.
    """

    def slot_names(domain):
        # Slots are compared through their names only.
        return [slot.name for slot in domain.slots]

    model_dir = tmpdir.strpath
    source_agent = Agent(
        "examples/moodbot/domain.yml", policies=[AugmentedMemoizationPolicy()]
    )
    source_agent.train(
        await source_agent.load_data("examples/moodbot/data/stories.md")
    )
    source_agent.persist(model_dir)

    reloaded_agent = Agent.load(model_dir)

    # Compare the persisted domain against the in-memory one.
    reloaded_domain = reloaded_agent.domain
    source_domain = source_agent.domain
    assert reloaded_domain.action_names == source_domain.action_names
    assert reloaded_domain.intents == source_domain.intents
    assert reloaded_domain.entities == source_domain.entities
    assert reloaded_domain.templates == source_domain.templates
    assert slot_names(reloaded_domain) == slot_names(source_domain)
exclusion_percentage: int = None,
kwargs: Optional[Dict] = None,
):
from rasa.core.agent import Agent
from rasa.core import config, utils
from rasa.core.run import AvailableEndpoints
if not endpoints:
endpoints = AvailableEndpoints()
if not kwargs:
kwargs = {}
policies = config.load(policy_config)
agent = Agent(
domain_file,
generator=endpoints.nlg,
action_endpoint=endpoints.action,
interpreter=interpreter,
policies=policies,
)
data_load_args, kwargs = utils.extract_args(
kwargs,
{
"use_story_concatenation",
"unique_last_num_states",
"augmentation_factor",
"remove_duplicates",
"debug_plots",
},
interpreter: Optional[NaturalLanguageInterpreter] = None,
generator: Union[EndpointConfig, NaturalLanguageGenerator] = None,
tracker_store: Optional[TrackerStore] = None,
lock_store: Optional[LockStore] = None,
action_endpoint: Optional[EndpointConfig] = None,
model_server: Optional[EndpointConfig] = None,
) -> Optional["Agent"]:
from rasa.nlu.persistor import get_persistor
persistor = get_persistor(remote_storage)
if persistor is not None:
target_path = tempfile.mkdtemp()
persistor.retrieve(model_name, target_path)
return Agent.load(
target_path,
interpreter=interpreter,
generator=generator,
tracker_store=tracker_store,
lock_store=lock_store,
action_endpoint=action_endpoint,
model_server=model_server,
remote_storage=remote_storage,
)
return None
max_history: int,
):
from rasa.core.agent import Agent
from rasa.core import config
try:
policies = config.load(config_path)
except ValueError as e:
print_error(
"Could not load config due to: '{}'. To specify a valid config file use "
"the '--config' argument.".format(e)
)
return
try:
agent = Agent(domain=domain_path, policies=policies)
except InvalidDomain as e:
print_error(
"Could not load domain due to: '{}'. To specify a valid domain path use "
"the '--domain' argument.".format(e)
)
return
# this is optional, only needed if the `/greet` type of
# messages in the stories should be replaced with actual
# messages (e.g. `hello`)
if nlu_data_path is not None:
from rasa.nlu.training_data import load_data
nlu_data_path = load_data(nlu_data_path)
else:
nlu_data_path = None
model_server: Optional[EndpointConfig] = None,
remote_storage: Optional[Text] = None,
) -> "Agent":
if os.path.isfile(model_path):
model_archive = model_path
else:
model_archive = get_latest_model(model_path)
if model_archive is None:
warnings.warn(f"Could not load local model in '{model_path}'.")
return Agent()
working_directory = tempfile.mkdtemp()
unpacked_model = unpack_model(model_archive, working_directory)
return Agent.load(
unpacked_model,
interpreter=interpreter,
generator=generator,
tracker_store=tracker_store,
lock_store=lock_store,
action_endpoint=action_endpoint,
model_server=model_server,
remote_storage=remote_storage,
path_to_model_archive=model_archive,
)
async def load_agent(
model_path: Optional[Text] = None,
model_server: Optional[EndpointConfig] = None,
remote_storage: Optional[Text] = None,
interpreter: Optional[NaturalLanguageInterpreter] = None,
generator: Union[EndpointConfig, NaturalLanguageGenerator] = None,
tracker_store: Optional[TrackerStore] = None,
lock_store: Optional[LockStore] = None,
action_endpoint: Optional[EndpointConfig] = None,
):
try:
if model_server is not None:
return await load_from_server(
Agent(
interpreter=interpreter,
generator=generator,
tracker_store=tracker_store,
lock_store=lock_store,
action_endpoint=action_endpoint,
model_server=model_server,
remote_storage=remote_storage,
),
model_server,
)
elif remote_storage is not None:
return Agent.load_from_remote_storage(
remote_storage,
model_path,
interpreter=interpreter,