Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add core and nlu additional params to training endpoint #8132

Merged
merged 14 commits into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog/4596.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Add `augmentation` and `num_threads` arguments to API `POST /model/train`

Fix boolean casting issue for `force_training` and `save_to_default_model_directory` arguments
12 changes: 12 additions & 0 deletions docs/static/spec/rasa.yml
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,18 @@ paths:
type: boolean
default: False
description: Force a model training even if the data has not changed
- in: query
wochinge marked this conversation as resolved.
Show resolved Hide resolved
name: augmentation
schema:
  # The server casts this query argument to an int (default 50), so the
  # schema type has to be integer for the default to be valid.
  type: integer
  default: 50
description: How much data augmentation to use during training
- in: query
  name: num_threads
  schema:
    # The server casts this query argument to an int (default 1), so the
    # schema type has to be integer for the default to be valid.
    type: integer
    default: 1
  description: Maximum amount of threads to use when training
- $ref: '#/components/parameters/callback_url'
requestBody:
required: true
Expand Down
27 changes: 22 additions & 5 deletions rasa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,6 @@ async def status(request: Request):
@ensure_loaded_agent(app)
async def retrieve_tracker(request: Request, conversation_id: Text):
"""Get a dump of a conversation's tracker including its events."""

verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)
until_time = rasa.utils.endpoints.float_arg(request, "until")

Expand Down Expand Up @@ -1469,7 +1468,7 @@ def _training_payload_from_json(
model_output_directory = str(temp_dir)
if request_payload.get(
"save_to_default_model_directory",
request.args.get("save_to_default_model_directory", True),
rasa.utils.endpoints.bool_arg(request, "save_to_default_model_directory", True),
):
model_output_directory = DEFAULT_MODELS_PATH

Expand All @@ -1479,8 +1478,10 @@ def _training_payload_from_json(
training_files=str(temp_dir),
output=model_output_directory,
force_training=request_payload.get(
"force", request.args.get("force_training", False)
"force", rasa.utils.endpoints.bool_arg(request, "force_training", False)
),
core_additional_arguments=_extract_core_additional_arguments(request),
nlu_additional_arguments=_extract_nlu_additional_arguments(request),
)


Expand Down Expand Up @@ -1532,15 +1533,17 @@ def _training_payload_from_yaml(
rasa.shared.utils.io.write_text_file(decoded, training_data)

model_output_directory = str(temp_dir)
if request.args.get("save_to_default_model_directory", True):
if rasa.utils.endpoints.bool_arg(request, "save_to_default_model_directory", True):
model_output_directory = DEFAULT_MODELS_PATH

return dict(
domain=str(training_data),
config=str(training_data),
training_files=str(temp_dir),
output=model_output_directory,
force_training=request.args.get("force_training", False),
force_training=rasa.utils.endpoints.bool_arg(request, "force_training", False),
core_additional_arguments=_extract_core_additional_arguments(request),
nlu_additional_arguments=_extract_nlu_additional_arguments(request),
)


Expand All @@ -1554,3 +1557,17 @@ def _validate_yaml_training_payload(yaml_text: Text) -> None:
f"The request body does not contain valid YAML. Error: {e}",
help_url=DOCS_URL_TRAINING_DATA,
)


def _extract_core_additional_arguments(request: Request) -> Dict:
    """Collects the additional Core training arguments from the query string.

    Args:
        request: Request whose query arguments are inspected.

    Returns:
        A dictionary with the key `augmentation_factor`, taken from the
        `augmentation` query argument via `int_arg` with a default of 50.
    """
    augmentation_factor = rasa.utils.endpoints.int_arg(request, "augmentation", 50)
    return {"augmentation_factor": augmentation_factor}


def _extract_nlu_additional_arguments(request: Request) -> Dict:
    """Collects the additional NLU training arguments from the query string.

    Args:
        request: Request whose query arguments are inspected.

    Returns:
        A dictionary with the key `num_threads`, taken from the `num_threads`
        query argument via `int_arg` with a default of 1.
    """
    num_threads = rasa.utils.endpoints.int_arg(request, "num_threads", 1)
    return {"num_threads": num_threads}
52 changes: 48 additions & 4 deletions rasa/utils/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,19 +203,35 @@ def bool_arg(request: Request, name: Text, default: bool = True) -> bool:
"""Return a passed boolean argument of the request or a default.

Checks the `name` parameter of the request if it contains a valid
boolean value. If not, `default` is returned."""
boolean value. If not, `default` is returned.

return request.args.get(name, str(default)).lower() == "true"
Args:
request: Sanic request.
name: Name of argument.
default: Default value for `name` argument.

Returns:
A bool value if `name` is a valid boolean, `default` otherwise.
"""
return str(request.args.get(name, default)).lower() == "true"


def float_arg(
request: Request, key: Text, default: Optional[float] = None
) -> Optional[float]:
"""Return a passed argument cast as a float or None.

Checks the `name` parameter of the request if it contains a valid
float value. If not, `None` is returned."""
Checks the `key` parameter of the request if it contains a valid
wochinge marked this conversation as resolved.
Show resolved Hide resolved
float value. If not, `default` is returned.

Args:
request: Sanic request.
key: Name of argument.
default: Default value for `key` argument.

Returns:
A float value if `key` is a valid float, `default` otherwise.
"""
arg = request.args.get(key, default)

if arg is default:
Expand All @@ -226,3 +242,31 @@ def float_arg(
except (ValueError, TypeError):
logger.warning(f"Failed to convert '{arg}' to float.")
return default


def int_arg(
wochinge marked this conversation as resolved.
Show resolved Hide resolved
request: Request, key: Text, default: Optional[int] = None
) -> Optional[int]:
"""Return a passed argument cast as an int or None.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please also describe the other parameters and the return values using Args: and Returns:?

Suggested change
"""Return a passed argument cast as an int or None.
"""Returns a passed argument cast as an int or None.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added them for bool_arg, float_arg and int_arg

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome. Minor comment: Could you please change Return to Returns? We try to use descriptive docstrings in the codebase.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I forgot to update this one


Checks the `key` parameter of the request if it contains a valid
int value. If not, `default` is returned.

Args:
request: Sanic request.
key: Name of argument.
default: Default value for `key` argument.

Returns:
An int value if `key` is a valid integer, `default` otherwise.
"""
arg = request.args.get(key, default)

if arg is default:
return arg

try:
return int(str(arg))
except (ValueError, TypeError):
logger.warning(f"Failed to convert '{arg}' to int.")
return default
81 changes: 81 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
import os
import sys
import time
import urllib.parse
import uuid
Expand Down Expand Up @@ -634,6 +635,52 @@ async def test_train_with_yaml(rasa_app: SanicASGITestClient, tmp_path: Path):
assert_trained_model(response.body, tmp_path)


@pytest.mark.parametrize(
"params", [{}, {"augmentation": 20, "num_threads": 2, "force_training": True}]
)
async def test_train_with_yaml_with_params(
monkeypatch: MonkeyPatch,
rasa_app: SanicASGITestClient,
tmpdir: pathlib.Path,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tmpdir is of type LocalPath which is slightly different. I suggest using tmp_path

Suggested change
tmpdir: pathlib.Path,
tmp_path: pathlib.Path,

params: Dict,
):
fake_model = Path(tmpdir) / "fake_model.tar.gz"
fake_model.touch()
fake_model_path = str(fake_model)
future = asyncio.Future()
future.set_result(TrainingResult(model=fake_model_path))
mock_train = Mock(return_value=future)
monkeypatch.setattr(sys.modules["rasa.train"], "train_async", mock_train)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we recently moved this:

Suggested change
monkeypatch.setattr(sys.modules["rasa.train"], "train_async", mock_train)
monkeypatch.setattr(rasa.model_training, "train_async", mock_train)


training_data = """
stories: []
rules: []
intents: []
nlu: []
responses: {}
language: en
policies: []
pipeline: []
"""
_, response = await rasa_app.post(
"/model/train",
data=training_data,
params=params,
headers={"Content-type": rasa.server.YAML_CONTENT_TYPE},
)

assert response.status == HTTPStatus.OK
assert mock_train.call_count == 1
args, kwargs = mock_train.call_args_list[0]
assert kwargs["core_additional_arguments"]["augmentation_factor"] == params.get(
"augmentation", 50
)
assert kwargs["nlu_additional_arguments"]["num_threads"] == params.get(
"num_threads", 1
)
assert kwargs["force_training"] == params.get("force_training", False)


async def test_train_with_invalid_yaml(rasa_app: SanicASGITestClient):
invalid_yaml = """
rules:
Expand Down Expand Up @@ -686,6 +733,40 @@ def test_training_payload_from_yaml_save_to_default_model_directory(
assert payload.get("output") == expected


@pytest.mark.parametrize(
"headers, expected",
[
({}, {"augmentation_factor": 50}),
({"augmentation": "25"}, {"augmentation_factor": 25}),
],
)
def test_training_payload_from_yaml_core_arguments(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be more robust to test with an actual http request (e.g. like done in this test function). This has the advantage that the test focuses on the functionality and less on the specific implementation (_training_payload_from_yaml)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you use monkeypatch to mock the actual training, you can test whether the parameters were passed correctly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome! I think we can even delete this and the ones after this then.

headers: Dict, expected: bool, tmp_path: Path
):
request = Mock()
request.body = b""
request.args = headers

payload = rasa.server._training_payload_from_yaml(request, tmp_path)
assert payload.get("core_additional_arguments") == expected


@pytest.mark.parametrize(
    "headers, expected",
    [({}, {"num_threads": 1}), ({"num_threads": "2"}, {"num_threads": 2})],
)
def test_training_payload_from_yaml_nlu_arguments(
    headers: Dict, expected: Dict, tmp_path: Path
):
    """Checks the `nlu_additional_arguments` entry of the YAML training payload.

    Args:
        headers: Mapping used as the mocked request's query arguments
            (`request.args`).
        expected: Expected value of `nlu_additional_arguments` in the payload.
        tmp_path: Temporary directory handed to `_training_payload_from_yaml`.
    """
    request = Mock()
    request.body = b""
    # Despite the parameter name, these values are served as query arguments.
    request.args = headers

    payload = rasa.server._training_payload_from_yaml(request, tmp_path)
    assert payload.get("nlu_additional_arguments") == expected


@pytest.mark.trains_model
wochinge marked this conversation as resolved.
Show resolved Hide resolved
async def test_train_missing_config(rasa_app: SanicASGITestClient):
payload = dict(domain="domain data", config=None)

Expand Down
48 changes: 47 additions & 1 deletion tests/utils/test_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Text
from typing import Text, Optional, Union
from unittest.mock import Mock

import pytest
from aioresponses import aioresponses
Expand Down Expand Up @@ -144,3 +145,48 @@ def test_read_endpoint_config(filename: Text, endpoint_type: Text):
def test_read_endpoint_config_not_found(filename: Text, endpoint_type: Text):
conf = endpoint_utils.read_endpoint_config(filename, endpoint_type)
assert conf is None


@pytest.mark.parametrize(
    "value, default, expected_result",
    [
        (None, True, True),
        (False, True, False),
        ("false", True, False),
        ("true", False, True),
    ],
)
def test_bool_arg(
    value: Optional[Union[bool, str]], default: bool, expected_result: bool
):
    """Checks `bool_arg` for a missing argument and for bool/string values.

    A `value` of `None` means the argument is absent from the request.
    """
    query_args = {} if value is None else {"key": value}
    request = Mock()
    request.args = query_args

    result = endpoint_utils.bool_arg(request, "key", default)
    assert result == expected_result


@pytest.mark.parametrize(
    "value, default, expected_result",
    [(None, 0.5, 0.5), (0.5, None, 0.5), ("0.5", 0, 0.5), ("a", 0.5, 0.5)],
)
def test_float_arg(
    value: Optional[Union[float, str]],
    default: Optional[float],
    expected_result: float,
):
    """Checks `float_arg` for missing, numeric, string and invalid values.

    A `value` of `None` means the argument is absent from the request.
    `default` may itself be `None`, as in the second parametrized case.
    """
    request = Mock()
    request.args = {}
    if value is not None:
        request.args = {"key": value}
    assert endpoint_utils.float_arg(request, "key", default) == expected_result


@pytest.mark.parametrize(
    "value, default, expected_result",
    [(None, 0, 0), (1, 0, 1), ("1", 0, 1), ("a", 0, 0)],
)
def test_int_arg(value: Optional[Union[int, str]], default: int, expected_result: int):
    """Checks `int_arg` for missing, integer, string and invalid values.

    A `value` of `None` means the argument is absent from the request.
    """
    query_args = {} if value is None else {"key": value}
    request = Mock()
    request.args = query_args

    result = endpoint_utils.int_arg(request, "key", default)
    assert result == expected_result