Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_file_broker_properly_logs_newlines(tmpdir):
    """A file event broker must keep each published event on its own line.

    The user message deliberately contains a newline; reading the log back
    line by line must still yield exactly the one event that was published.
    """
    log_path = tmpdir.join("events.log").strpath
    broker = EventBroker.create(EndpointConfig(type="file", path=log_path))

    multiline_event = UserUttered("hello \n there")
    broker.publish(multiline_event.as_dict())

    # One JSON-serialised event is expected per line of the log file.
    with open(log_path, "r") as handle:
        restored_events = [
            Event.from_parameters(json.loads(raw_line)) for raw_line in handle
        ]

    assert restored_events == [multiline_event]
async def test_http_interpreter(endpoint_url, joined_url):
    """The HTTP interpreter POSTs text, token and message id to `joined_url`."""
    expected_payload = {"text": "message_text", "token": None, "message_id": "message_id"}

    with aioresponses() as mocked:
        mocked.post(joined_url)

        interpreter = RasaNLUHttpInterpreter(
            endpoint_config=EndpointConfig(endpoint_url)
        )
        await interpreter.parse(text="message_text", message_id="message_id")

        # Inspect the body of the request the interpreter actually sent.
        sent_payload = json_of_latest_request(
            latest_request(mocked, "POST", joined_url)
        )
        assert sent_payload == expected_payload
async def test_parsing_with_tracker():
    """Parsing through the agent must expose a tracker slot value in the result."""
    slot_name = "requested_language"
    tracker = DialogueStateTracker.from_dict("1", [], [Slot(slot_name)])
    # The interpreter result is expected to contain this slot value ('en').
    tracker._set_slot(slot_name, "en")

    endpoint = EndpointConfig("https://interpreter.com")
    with aioresponses() as mocked:
        mocked.post("https://interpreter.com/parse", repeat=True, status=200)

        # Replace the real `parse` with the test-specific `mocked_parse`.
        with patch.object(RasaNLUHttpInterpreter, "parse", mocked_parse):
            agent = Agent(None, None, RasaNLUHttpInterpreter(endpoint=endpoint))
            parsed = await agent.parse_message_using_nlu_interpreter("lunch?", tracker)
            assert parsed[slot_name] == "en"
# Tail of a function whose header is not visible in this chunk. It handles an
# event broker that `_is_correct_event_broker` rejects for a local Rasa X run.
# NOTE(review): `overwrite_existing_event_broker` is read in the final `if`
# but only assigned inside the first `if`. It must be initialised (e.g. to
# False) earlier in the enclosing function, which this chunk does not show — confirm.
if endpoints.event_broker and not _is_correct_event_broker(endpoints.event_broker):
cli_utils.print_error(
"Rasa X currently only supports a SQLite event broker with path '{}' "
"when running locally. You can deploy Rasa X with Docker "
"(https://rasa.com/docs/rasa-x/deploy/) if you want to use "
"other event broker configurations.".format(DEFAULT_EVENTS_DB)
)
# Ask the user whether to fall back to the default SQLite event broker.
overwrite_existing_event_broker = questionary.confirm(
"Do you want to continue with the default SQLite event broker?"
).ask()
# Declining ends the process with exit status 0.
# NOTE(review): `exit` is the helper added by the `site` module; `sys.exit`
# is the dependable choice outside the interactive interpreter.
if not overwrite_existing_event_broker:
exit(0)
# Install the SQLite event broker when no tracker store is configured or the
# user agreed to replace the existing broker.
if not endpoints.tracker_store or overwrite_existing_event_broker:
endpoints.event_broker = EndpointConfig(type="sql", db=DEFAULT_EVENTS_DB)
# Tail of a training helper whose header is not visible in this chunk.
# Persist via `trainer.persist` only when an output `path` was supplied.
if path:
persisted_path = trainer.persist(path, persistor, fixed_model_name)
else:
persisted_path = None
# Returns (trainer, interpreter, persisted_path); persisted_path is None when
# nothing was persisted.
return trainer, interpreter, persisted_path
if __name__ == '__main__':
    # Script entry point: parse CLI arguments, pick the training-data source,
    # then run training.
    cmdline_args = create_argument_parser().parse_args()
    utils.configure_colored_logging(cmdline_args.loglevel)

    # A URL given on the command line takes precedence over the data endpoint
    # read from the endpoints file.
    data_endpoint = (
        EndpointConfig(cmdline_args.url)
        if cmdline_args.url
        else read_endpoints(cmdline_args.endpoints).data
    )

    train(
        cmdline_args.config,
        cmdline_args.data,
        cmdline_args.path,
        cmdline_args.project,
        cmdline_args.fixed_model_name,
        cmdline_args.storage,
        training_data_endpoint=data_endpoint,
        num_threads=cmdline_args.num_threads,
    )
    logger.info("Finished training")
def copy(self) -> "EndpointConfig":
    """Return a new EndpointConfig built from this instance's settings.

    The url, params, headers, basic_auth, token and token_name are passed
    positionally, followed by the extra keyword settings in `self.kwargs`.
    The same objects are passed through, so any mutable values (e.g. dicts)
    are shared with this instance rather than copied.
    """
    positional_settings = (
        self.url,
        self.params,
        self.headers,
        self.basic_auth,
        self.token,
        self.token_name,
    )
    return EndpointConfig(*positional_settings, **self.kwargs)
# Adjust `endpoints` in place for a local Rasa X run.
# - Point the model endpoint at Rasa X's production-tagged model.
# - If the configured event broker is rejected by `_is_correct_event_broker`,
#   print an error and ask whether to fall back to the default SQLite broker.
# NOTE(review): this chunk stops partway through the function; the handling of
# the user's answer is not visible here.
def _overwrite_endpoints_for_local_x(
endpoints: AvailableEndpoints, rasa_x_token: Text, rasa_x_url: Text
):
# Imported inside the function rather than at module level.
from rasa.utils.endpoints import EndpointConfig
import questionary
# Always use the production-tagged model URL on Rasa X with the Rasa X token.
# wait_time_between_pulls=2 — units assumed to be seconds; confirm.
endpoints.model = EndpointConfig(
"{}/projects/default/models/tags/production".format(rasa_x_url),
token=rasa_x_token,
wait_time_between_pulls=2,
)
# Default: do not replace the event broker unless the user agrees below.
overwrite_existing_event_broker = False
if endpoints.event_broker and not _is_correct_event_broker(endpoints.event_broker):
cli_utils.print_error(
"Rasa X currently only supports a SQLite event broker with path '{}' "
"when running locally. You can deploy Rasa X with Docker "
"(https://rasa.com/docs/rasa-x/deploy/) if you want to use "
"other event broker configurations.".format(DEFAULT_EVENTS_DB)
)
overwrite_existing_event_broker = questionary.confirm(
"Do you want to continue with the default SQLite event broker?"
).ask()
# NOTE(review): the closing parenthesis below has no matching opener in this
# chunk; it looks like residue from how the snippet was extracted — verify.
)
# Tail of a helper whose header is not visible in this chunk. It appears to
# mirror the body of `_get_model_endpoint`: build the Rasa X model endpoint
# from `default_rasax_model_server_url` and `rasa_x_token`.
# Use an empty EndpointConfig when no model endpoint was configured.
model_endpoint = model_endpoint or EndpointConfig()
# If endpoints.yml configures a different model URL, log (at info level) that
# it is ignored in favour of the Rasa X model server URL.
custom_url = model_endpoint.url
if custom_url and custom_url != default_rasax_model_server_url:
logger.info(
f"Ignoring url '{custom_url}' from 'endpoints.yml' and using "
f"'{default_rasax_model_server_url}' instead."
)
# Keep a user-configured polling interval, otherwise use 2.
# NOTE(review): because of `or 2`, a configured value of 0 also becomes 2.
custom_wait_time_pulls = model_endpoint.kwargs.get("wait_time_between_pulls")
return EndpointConfig(
default_rasax_model_server_url,
token=rasa_x_token,
wait_time_between_pulls=custom_wait_time_pulls or 2,
)
def __init__(self, endpoint: EndpointConfig = None) -> None:
    # Keep the given endpoint when it is truthy; otherwise fall back to an
    # endpoint for constants.DEFAULT_SERVER_URL.
    self.endpoint = endpoint or EndpointConfig(constants.DEFAULT_SERVER_URL)
# Build the model endpoint used for a local Rasa X run.
#
# The URL is always `<rasa_x_url>/projects/default/models/tags/production`.
# A different URL configured in `model_endpoint` is ignored, and an info
# message is logged. `wait_time_between_pulls` comes from
# `model_endpoint.kwargs` when set (otherwise 2), and `rasa_x_token` is
# attached as the token.
# NOTE(review): this chunk ends before the closing parenthesis of the return call.
def _get_model_endpoint(
model_endpoint: Optional[EndpointConfig], rasa_x_token: Text, rasa_x_url: Text
) -> EndpointConfig:
# If you change that, please run a test with Rasa X and speak to the bot
default_rasax_model_server_url = (
f"{rasa_x_url}/projects/default/models/tags/production"
)
# Treat a missing model endpoint as an empty configuration.
model_endpoint = model_endpoint or EndpointConfig()
# If endpoints.yml configures a different model URL, log (at info level) that
# it is ignored in favour of the Rasa X model server URL.
custom_url = model_endpoint.url
if custom_url and custom_url != default_rasax_model_server_url:
logger.info(
f"Ignoring url '{custom_url}' from 'endpoints.yml' and using "
f"'{default_rasax_model_server_url}' instead."
)
# Keep a user-configured polling interval, otherwise use 2.
# NOTE(review): because of `or 2`, a configured value of 0 also becomes 2.
custom_wait_time_pulls = model_endpoint.kwargs.get("wait_time_between_pulls")
return EndpointConfig(
default_rasax_model_server_url,
token=rasa_x_token,
wait_time_between_pulls=custom_wait_time_pulls or 2,