Skip to content

Commit

Permalink
Update function comments and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
theanht1 committed Mar 20, 2021
1 parent 6892e37 commit 7ff917f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 47 deletions.
12 changes: 4 additions & 8 deletions rasa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,9 +1445,7 @@ def _test_data_file_from_payload(
)


def _training_payload_from_json(
request: Request, temp_dir: Path
) -> Dict[Text, Union[Text, bool]]:
def _training_payload_from_json(request: Request, temp_dir: Path) -> Dict[Text, Any]:
logger.debug(
"Extracting JSON payload with Markdown training data from request body."
)
Expand Down Expand Up @@ -1534,9 +1532,7 @@ def _validate_json_training_payload(rjs: Dict):
)


def _training_payload_from_yaml(
request: Request, temp_dir: Path
) -> Dict[Text, Union[Text, bool]]:
def _training_payload_from_yaml(request: Request, temp_dir: Path) -> Dict[Text, Any]:
logger.debug("Extracting YAML training data from request body.")

decoded = request.body.decode(rasa.shared.utils.io.DEFAULT_ENCODING)
Expand Down Expand Up @@ -1572,15 +1568,15 @@ def _validate_yaml_training_payload(yaml_text: Text) -> None:
)


def _extract_core_additional_arguments(request: Request) -> Dict:
def _extract_core_additional_arguments(request: Request) -> Dict[Text, Any]:
return {
"augmentation_factor": rasa.utils.endpoints.int_arg(
request, "augmentation", 50
),
}


def _extract_nlu_additional_arguments(request: Request) -> Dict:
def _extract_nlu_additional_arguments(request: Request) -> Dict[Text, Any]:
return {
"num_threads": rasa.utils.endpoints.int_arg(request, "num_threads", 1),
}
6 changes: 3 additions & 3 deletions rasa/utils/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def __init__(self, status: int, message: Text, text: Text) -> None:


def bool_arg(request: Request, name: Text, default: bool = True) -> bool:
"""Return a passed boolean argument of the request or a default.
"""Returns a passed boolean argument of the request or a default.
Checks the `name` parameter of the request if it contains a valid
boolean value. If not, `default` is returned.
Expand All @@ -219,7 +219,7 @@ def bool_arg(request: Request, name: Text, default: bool = True) -> bool:
def float_arg(
request: Request, key: Text, default: Optional[float] = None
) -> Optional[float]:
"""Return a passed argument cast as a float or None.
"""Returns a passed argument cast as a float or None.
Checks the `key` parameter of the request if it contains a valid
float value. If not, `default` is returned.
Expand Down Expand Up @@ -247,7 +247,7 @@ def float_arg(
def int_arg(
request: Request, key: Text, default: Optional[int] = None
) -> Optional[int]:
"""Return a passed argument cast as an int or None.
"""Returns a passed argument cast as an int or None.
Checks the `key` parameter of the request if it contains a valid
int value. If not, `default` is returned.
Expand Down
39 changes: 3 additions & 36 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,16 +641,16 @@ async def test_train_with_yaml(rasa_app: SanicASGITestClient, tmp_path: Path):
async def test_train_with_yaml_with_params(
monkeypatch: MonkeyPatch,
rasa_app: SanicASGITestClient,
tmpdir: pathlib.Path,
tmp_path: Path,
params: Dict,
):
fake_model = Path(tmpdir) / "fake_model.tar.gz"
fake_model = Path(tmp_path) / "fake_model.tar.gz"
fake_model.touch()
fake_model_path = str(fake_model)
future = asyncio.Future()
future.set_result(TrainingResult(model=fake_model_path))
mock_train = Mock(return_value=future)
monkeypatch.setattr(sys.modules["rasa.train"], "train_async", mock_train)
monkeypatch.setattr(rasa.model_training, "train_async", mock_train)

training_data = """
stories: []
Expand Down Expand Up @@ -733,39 +733,6 @@ def test_training_payload_from_yaml_save_to_default_model_directory(
assert payload.get("output") == expected


@pytest.mark.parametrize(
"headers, expected",
[
({}, {"augmentation_factor": 50}),
({"augmentation": "25"}, {"augmentation_factor": 25}),
],
)
def test_training_payload_from_yaml_core_arguments(
headers: Dict, expected: bool, tmp_path: Path
):
request = Mock()
request.body = b""
request.args = headers

payload = rasa.server._training_payload_from_yaml(request, tmp_path)
assert payload.get("core_additional_arguments") == expected


@pytest.mark.parametrize(
"headers, expected",
[({}, {"num_threads": 1}), ({"num_threads": "2"}, {"num_threads": 2})],
)
def test_training_payload_from_yaml_nlu_arguments(
headers: Dict, expected: bool, tmp_path: Path
):
request = Mock()
request.body = b""
request.args = headers

payload = rasa.server._training_payload_from_yaml(request, tmp_path)
assert payload.get("nlu_additional_arguments") == expected


@pytest.mark.trains_model
async def test_train_missing_config(rasa_app: SanicASGITestClient):
payload = dict(domain="domain data", config=None)
Expand Down

0 comments on commit 7ff917f

Please sign in to comment.