Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add core and nlu additional params to training endpoint #8132

Merged
merged 14 commits into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/static/spec/rasa.yml
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,18 @@ paths:
type: boolean
default: False
description: Force a model training even if the data has not changed
- in: query
wochinge marked this conversation as resolved.
Show resolved Hide resolved
name: augmentation
schema:
type: string
default: 50
description: How much data augmentation to use during training
- in: query
name: num_threads
schema:
type: string
default: 1
description: Maximum amount of threads to use when training
- $ref: '#/components/parameters/callback_url'
requestBody:
required: true
Expand Down
41 changes: 26 additions & 15 deletions rasa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from rasa.core.utils import AvailableEndpoints
from rasa.nlu.emulators.no_emulator import NoEmulator
from rasa.nlu.test import run_evaluation, CVEvaluationResult
from rasa.utils.endpoints import EndpointConfig
from rasa.utils.endpoints import EndpointConfig, bool_arg, float_arg, int_arg
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have an internal code convention that we don't import functions directly, in favor of explicitly referring to them by module. So it would be great if you could also refer to bool_arg etc. using rasa.utils.endpoints.bool_arg 🙌🏻

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


if TYPE_CHECKING:
from ssl import SSLContext
Expand Down Expand Up @@ -699,9 +699,8 @@ async def status(request: Request):
@ensure_loaded_agent(app)
async def retrieve_tracker(request: Request, conversation_id: Text):
"""Get a dump of a conversation's tracker including its events."""

verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)
until_time = rasa.utils.endpoints.float_arg(request, "until")
until_time = float_arg(request, "until")

tracker = await app.agent.create_processor().fetch_tracker_with_initial_session(
conversation_id
Expand Down Expand Up @@ -745,9 +744,7 @@ async def append_events(request: Request, conversation_id: Text):

output_channel = _get_output_channel(request, tracker)

if rasa.utils.endpoints.bool_arg(
request, EXECUTE_SIDE_EFFECTS_QUERY_KEY, False
):
if bool_arg(request, EXECUTE_SIDE_EFFECTS_QUERY_KEY, False):
await processor.execute_side_effects(
events, tracker, output_channel
)
Expand Down Expand Up @@ -823,10 +820,8 @@ async def replace_events(request: Request, conversation_id: Text):
@ensure_conversation_exists()
async def retrieve_story(request: Request, conversation_id: Text):
"""Get an end-to-end story corresponding to this conversation."""
until_time = rasa.utils.endpoints.float_arg(request, "until")
fetch_all_sessions = rasa.utils.endpoints.bool_arg(
request, "all_sessions", default=False
)
until_time = float_arg(request, "until")
fetch_all_sessions = bool_arg(request, "all_sessions", default=False)

try:
stories = get_test_stories(
Expand Down Expand Up @@ -1091,7 +1086,7 @@ async def evaluate_stories(

test_data = _test_data_file_from_payload(request, temporary_directory, ".md")

use_e2e = rasa.utils.endpoints.bool_arg(request, "e2e", default=False)
use_e2e = bool_arg(request, "e2e", default=False)

try:
evaluation = await test(
Expand Down Expand Up @@ -1469,7 +1464,7 @@ def _training_payload_from_json(
model_output_directory = str(temp_dir)
if request_payload.get(
"save_to_default_model_directory",
request.args.get("save_to_default_model_directory", True),
bool_arg(request, "save_to_default_model_directory", True),
):
model_output_directory = DEFAULT_MODELS_PATH

Expand All @@ -1479,8 +1474,10 @@ def _training_payload_from_json(
training_files=str(temp_dir),
output=model_output_directory,
force_training=request_payload.get(
"force", request.args.get("force_training", False)
"force", bool_arg(request, "force_training", False)
),
core_additional_arguments=_extract_core_additional_arguments(request),
nlu_additional_arguments=_extract_nlu_additional_arguments(request),
)


Expand Down Expand Up @@ -1532,15 +1529,17 @@ def _training_payload_from_yaml(
rasa.shared.utils.io.write_text_file(decoded, training_data)

model_output_directory = str(temp_dir)
if request.args.get("save_to_default_model_directory", True):
if bool_arg(request, "save_to_default_model_directory", True):
model_output_directory = DEFAULT_MODELS_PATH

return dict(
domain=str(training_data),
config=str(training_data),
training_files=str(temp_dir),
output=model_output_directory,
force_training=request.args.get("force_training", False),
force_training=bool_arg(request, "force_training", False),
core_additional_arguments=_extract_core_additional_arguments(request),
nlu_additional_arguments=_extract_nlu_additional_arguments(request),
)


Expand All @@ -1554,3 +1553,15 @@ def _validate_yaml_training_payload(yaml_text: Text) -> None:
f"The request body does not contain valid YAML. Error: {e}",
help_url=DOCS_URL_TRAINING_DATA,
)


def _extract_core_additional_arguments(request: Request) -> Dict:
    """Collects the additional Core training arguments from the query string.

    Args:
        request: The incoming training request whose query arguments are read.

    Returns:
        A dictionary with the single key `augmentation_factor`, holding the
        result of `int_arg` for the `augmentation` query parameter (the
        fallback passed to `int_arg` is `50`).
    """
    augmentation_factor = int_arg(request, "augmentation", 50)
    return {"augmentation_factor": augmentation_factor}


def _extract_nlu_additional_arguments(request: Request) -> Dict:
    """Collects the additional NLU training arguments from the query string.

    Args:
        request: The incoming training request whose query arguments are read.

    Returns:
        A dictionary with the single key `num_threads`, holding the result of
        `int_arg` for the `num_threads` query parameter (the fallback passed
        to `int_arg` is `1`).
    """
    num_threads = int_arg(request, "num_threads", 1)
    return {"num_threads": num_threads}
32 changes: 26 additions & 6 deletions rasa/utils/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,19 +203,19 @@ def bool_arg(request: Request, name: Text, default: bool = True) -> bool:
"""Return a passed boolean argument of the request or a default.

Checks the `name` parameter of the request if it contains a valid
boolean value. If not, `default` is returned."""

return request.args.get(name, str(default)).lower() == "true"
boolean value. If not, `default` is returned.
"""
return str(request.args.get(name, default)).lower() == "true"


def float_arg(
request: Request, key: Text, default: Optional[float] = None
) -> Optional[float]:
"""Return a passed argument cast as a float or None.

Checks the `name` parameter of the request if it contains a valid
float value. If not, `None` is returned."""

Checks the `key` parameter of the request if it contains a valid
wochinge marked this conversation as resolved.
Show resolved Hide resolved
float value. If not, `default` is returned.
"""
arg = request.args.get(key, default)

if arg is default:
Expand All @@ -226,3 +226,23 @@ def float_arg(
except (ValueError, TypeError):
logger.warning(f"Failed to convert '{arg}' to float.")
return default


def int_arg(
wochinge marked this conversation as resolved.
Show resolved Hide resolved
request: Request, key: Text, default: Optional[int] = None
) -> Optional[int]:
"""Return a passed argument cast as an int or None.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please also describe the other parameters and the return values using Args: and Returns:?

Suggested change
"""Return a passed argument cast as an int or None.
"""Returns a passed argument cast as an int or None.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added for bool_arg, float_arg and int_arg

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome. Minor comment: Could you please change Return to Returns? We try to use descriptive docstrings in the codebase.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I forgot to update this one


Checks the `key` parameter of the request if it contains a valid
int value. If not, `default` is returned.
"""
arg = request.args.get(key, default)

if arg is default:
return arg

try:
return int(str(arg))
except (ValueError, TypeError):
logger.warning(f"Failed to convert '{arg}' to int.")
return default
33 changes: 33 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,39 @@ def test_training_payload_from_yaml_save_to_default_model_directory(
assert payload.get("output") == expected


@pytest.mark.parametrize(
"headers, expected",
[
({}, {"augmentation_factor": 50}),
({"augmentation": "25"}, {"augmentation_factor": 25}),
],
)
def test_training_payload_from_yaml_core_arguments(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be more robust to test with an actual HTTP request (e.g. as done in this test function). This has the advantage that the test focuses on the functionality and less on the specific implementation (_training_payload_from_yaml).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you use monkeypatch to mock the actual training, you can test whether the parameters were passed correctly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome! I think we can even delete this and the ones after this then.

headers: Dict, expected: bool, tmp_path: Path
):
request = Mock()
request.body = b""
request.args = headers

payload = rasa.server._training_payload_from_yaml(request, tmp_path)
assert payload.get("core_additional_arguments") == expected


@pytest.mark.parametrize(
    "headers, expected",
    [
        ({}, {"num_threads": 1}),
        ({"num_threads": "2"}, {"num_threads": 2}),
    ],
)
def test_training_payload_from_yaml_nlu_arguments(
    headers: Dict, expected: Dict, tmp_path: Path
) -> None:
    """Checks the NLU arguments in the payload built from a YAML request.

    Args:
        headers: Mapping installed as the mocked request's `args`, i.e. it
            stands in for the request's query arguments.
        expected: The dictionary expected under `nlu_additional_arguments`.
        tmp_path: Temporary directory handed to the payload builder.
    """
    request = Mock()
    # An empty body is enough here; only the query arguments are under test.
    request.body = b""
    request.args = headers

    payload = rasa.server._training_payload_from_yaml(request, tmp_path)
    assert payload.get("nlu_additional_arguments") == expected


@pytest.mark.trains_model
wochinge marked this conversation as resolved.
Show resolved Hide resolved
async def test_train_missing_config(rasa_app: SanicASGITestClient):
payload = dict(domain="domain data", config=None)
Expand Down