How to use the rasa.core.agent.Agent class in rasa

To help you get started, we've selected a few rasa examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github RasaHQ / rasa / tests / core / test_training.py View on Github external
async def test_training_script_without_max_history_set(tmpdir):
    """Train with the ``no_max_hist_config.yml`` policy config and check
    the ``max_history`` of every loaded policy's featurizer.

    A policy whose exact type is ``FormPolicy`` must report ``2``; any
    other policy must report its featurizer's ``MAX_HISTORY_DEFAULT``.
    Featurizers without a ``max_history`` attribute are not checked.
    """
    model_dir = tmpdir.strpath
    await train(
        DEFAULT_DOMAIN_PATH_WITH_SLOTS,
        DEFAULT_STORIES_FILE,
        model_dir,
        interpreter=RegexInterpreter(),
        policy_config="data/test_config/no_max_hist_config.yml",
        kwargs={},
    )

    loaded_agent = Agent.load(model_dir)
    for candidate in loaded_agent.policy_ensemble.policies:
        featurizer = candidate.featurizer
        if not hasattr(featurizer, "max_history"):
            # Nothing to verify for featurizers that carry no history limit.
            continue

        # Exact type comparison on purpose: only FormPolicy itself gets 2.
        if type(candidate) == FormPolicy:
            expected_history = 2
        else:
            expected_history = featurizer.MAX_HISTORY_DEFAULT

        assert featurizer.max_history == expected_history
github botfront / rasa-for-botfront / tests / core / test_agent.py View on Github external
def test_two_stage_fallback_without_deny_suggestion(domain, policy_config):
    """Building an ``Agent`` from the given domain and policy config must
    raise ``InvalidDomain`` with a message about the ``out_of_scope`` intent.
    """
    expected_fragment = "The intent 'out_of_scope' must be present"

    with pytest.raises(InvalidDomain) as raised:
        # Parsing stays inside the context manager so an InvalidDomain raised
        # while building either argument is captured as well.
        parsed_domain = Domain.from_dict(domain)
        ensemble = PolicyEnsemble.from_dict(policy_config)
        Agent(domain=parsed_domain, policies=ensemble)

    assert expected_fragment in str(raised.value)
github RasaHQ / rasa / tests / core / test_agent.py View on Github external
async def test_training_data_is_reproducible(tmpdir, default_domain):
    """Loading the same story file twice must yield identical training data.

    The two loads are compared element by element (so ordering matters) via
    the string form of each item's ``as_dialogue()``. The lengths are also
    compared, so a second load that returns extra or missing items fails
    with a clear assertion instead of passing silently or raising
    ``IndexError``.
    """
    training_data_file = "examples/moodbot/data/stories.md"
    agent = Agent(
        "examples/moodbot/domain.yml", policies=[AugmentedMemoizationPolicy()]
    )

    training_data = await agent.load_data(training_data_file)
    # make another copy of training data
    same_training_data = await agent.load_data(training_data_file)

    # zip() stops at the shorter input, so the sizes are pinned first.
    assert len(training_data) == len(same_training_data)

    # test if both datasets are identical (including in the same order)
    for first, second in zip(training_data, same_training_data):
        assert str(first.as_dialogue()) == str(second.as_dialogue())
github RasaHQ / rasa / tests / core / test_agent.py View on Github external
async def test_agent_train(tmpdir, default_domain):
    """Train an agent on the moodbot stories, persist it, load it back and
    check that the reloaded agent's domain matches the trained agent's.
    """
    agent = Agent(
        "examples/moodbot/domain.yml", policies=[AugmentedMemoizationPolicy()]
    )
    stories = await agent.load_data("examples/moodbot/data/stories.md")
    agent.train(stories)

    model_dir = tmpdir.strpath
    agent.persist(model_dir)
    restored = Agent.load(model_dir)

    # test domain
    for attribute in ("action_names", "intents", "entities", "templates"):
        assert getattr(restored.domain, attribute) == getattr(
            agent.domain, attribute
        )

    # Slots are compared by name only.
    restored_slot_names = [slot.name for slot in restored.domain.slots]
    trained_slot_names = [slot.name for slot in agent.domain.slots]
    assert restored_slot_names == trained_slot_names
github RasaHQ / rasa / tests / core / test_agent.py View on Github external
async def test_agent_train(tmpdir, default_domain):
    """Train an agent on the moodbot stories, persist it, load it back and
    check that the reloaded agent's domain matches the trained agent's.
    """
    agent = Agent(
        "examples/moodbot/domain.yml", policies=[AugmentedMemoizationPolicy()]
    )
    stories = await agent.load_data("examples/moodbot/data/stories.md")
    agent.train(stories)

    model_dir = tmpdir.strpath
    agent.persist(model_dir)
    restored = Agent.load(model_dir)

    # test domain
    for attribute in ("action_names", "intents", "entities", "templates"):
        assert getattr(restored.domain, attribute) == getattr(
            agent.domain, attribute
        )

    # Slots are compared by name only.
    restored_slot_names = [slot.name for slot in restored.domain.slots]
    trained_slot_names = [slot.name for slot in agent.domain.slots]
    assert restored_slot_names == trained_slot_names
github RasaHQ / rasa / rasa / core / train.py View on Github external
exclusion_percentage: int = None,
    kwargs: Optional[Dict] = None,
):
    from rasa.core.agent import Agent
    from rasa.core import config, utils
    from rasa.core.run import AvailableEndpoints

    if not endpoints:
        endpoints = AvailableEndpoints()

    if not kwargs:
        kwargs = {}

    policies = config.load(policy_config)

    agent = Agent(
        domain_file,
        generator=endpoints.nlg,
        action_endpoint=endpoints.action,
        interpreter=interpreter,
        policies=policies,
    )

    data_load_args, kwargs = utils.extract_args(
        kwargs,
        {
            "use_story_concatenation",
            "unique_last_num_states",
            "augmentation_factor",
            "remove_duplicates",
            "debug_plots",
        },
github RasaHQ / rasa / rasa / core / agent.py View on Github external
interpreter: Optional[NaturalLanguageInterpreter] = None,
        generator: Union[EndpointConfig, NaturalLanguageGenerator] = None,
        tracker_store: Optional[TrackerStore] = None,
        lock_store: Optional[LockStore] = None,
        action_endpoint: Optional[EndpointConfig] = None,
        model_server: Optional[EndpointConfig] = None,
    ) -> Optional["Agent"]:
        from rasa.nlu.persistor import get_persistor

        persistor = get_persistor(remote_storage)

        if persistor is not None:
            target_path = tempfile.mkdtemp()
            persistor.retrieve(model_name, target_path)

            return Agent.load(
                target_path,
                interpreter=interpreter,
                generator=generator,
                tracker_store=tracker_store,
                lock_store=lock_store,
                action_endpoint=action_endpoint,
                model_server=model_server,
                remote_storage=remote_storage,
            )

        return None
github RasaHQ / rasa / rasa / core / visualize.py View on Github external
max_history: int,
):
    from rasa.core.agent import Agent
    from rasa.core import config

    try:
        policies = config.load(config_path)
    except ValueError as e:
        print_error(
            "Could not load config due to: '{}'. To specify a valid config file use "
            "the '--config' argument.".format(e)
        )
        return

    try:
        agent = Agent(domain=domain_path, policies=policies)
    except InvalidDomain as e:
        print_error(
            "Could not load domain due to: '{}'. To specify a valid domain path use "
            "the '--domain' argument.".format(e)
        )
        return

    # this is optional, only needed if the `/greet` type of
    # messages in the stories should be replaced with actual
    # messages (e.g. `hello`)
    if nlu_data_path is not None:
        from rasa.nlu.training_data import load_data

        nlu_data_path = load_data(nlu_data_path)
    else:
        nlu_data_path = None
github RasaHQ / rasa / rasa / core / agent.py View on Github external
model_server: Optional[EndpointConfig] = None,
        remote_storage: Optional[Text] = None,
    ) -> "Agent":
        if os.path.isfile(model_path):
            model_archive = model_path
        else:
            model_archive = get_latest_model(model_path)

        if model_archive is None:
            warnings.warn(f"Could not load local model in '{model_path}'.")
            return Agent()

        working_directory = tempfile.mkdtemp()
        unpacked_model = unpack_model(model_archive, working_directory)

        return Agent.load(
            unpacked_model,
            interpreter=interpreter,
            generator=generator,
            tracker_store=tracker_store,
            lock_store=lock_store,
            action_endpoint=action_endpoint,
            model_server=model_server,
            remote_storage=remote_storage,
            path_to_model_archive=model_archive,
        )
github RasaHQ / rasa / rasa / core / agent.py View on Github external
async def load_agent(
    model_path: Optional[Text] = None,
    model_server: Optional[EndpointConfig] = None,
    remote_storage: Optional[Text] = None,
    interpreter: Optional[NaturalLanguageInterpreter] = None,
    generator: Union[EndpointConfig, NaturalLanguageGenerator] = None,
    tracker_store: Optional[TrackerStore] = None,
    lock_store: Optional[LockStore] = None,
    action_endpoint: Optional[EndpointConfig] = None,
):
    try:
        if model_server is not None:
            return await load_from_server(
                Agent(
                    interpreter=interpreter,
                    generator=generator,
                    tracker_store=tracker_store,
                    lock_store=lock_store,
                    action_endpoint=action_endpoint,
                    model_server=model_server,
                    remote_storage=remote_storage,
                ),
                model_server,
            )

        elif remote_storage is not None:
            return Agent.load_from_remote_storage(
                remote_storage,
                model_path,
                interpreter=interpreter,