Skip to content

Commit

Permalink
Update function comments and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
theanht1 committed Mar 20, 2021
1 parent 6892e37 commit 24a0db3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 39 deletions.
6 changes: 3 additions & 3 deletions rasa/utils/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def __init__(self, status: int, message: Text, text: Text) -> None:


def bool_arg(request: Request, name: Text, default: bool = True) -> bool:
"""Return a passed boolean argument of the request or a default.
"""Returns a passed boolean argument of the request or a default.
Checks the `name` parameter of the request if it contains a valid
boolean value. If not, `default` is returned.
Expand All @@ -219,7 +219,7 @@ def bool_arg(request: Request, name: Text, default: bool = True) -> bool:
def float_arg(
request: Request, key: Text, default: Optional[float] = None
) -> Optional[float]:
"""Return a passed argument cast as a float or None.
"""Returns a passed argument cast as a float or None.
Checks the `key` parameter of the request if it contains a valid
float value. If not, `default` is returned.
Expand Down Expand Up @@ -247,7 +247,7 @@ def float_arg(
def int_arg(
request: Request, key: Text, default: Optional[int] = None
) -> Optional[int]:
"""Return a passed argument cast as an int or None.
"""Returns a passed argument cast as an int or None.
Checks the `key` parameter of the request if it contains a valid
int value. If not, `default` is returned.
Expand Down
39 changes: 3 additions & 36 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,16 +641,16 @@ async def test_train_with_yaml(rasa_app: SanicASGITestClient, tmp_path: Path):
async def test_train_with_yaml_with_params(
monkeypatch: MonkeyPatch,
rasa_app: SanicASGITestClient,
tmpdir: pathlib.Path,
tmp_path: Path,
params: Dict,
):
fake_model = Path(tmpdir) / "fake_model.tar.gz"
fake_model = Path(tmp_path) / "fake_model.tar.gz"
fake_model.touch()
fake_model_path = str(fake_model)
future = asyncio.Future()
future.set_result(TrainingResult(model=fake_model_path))
mock_train = Mock(return_value=future)
monkeypatch.setattr(sys.modules["rasa.train"], "train_async", mock_train)
monkeypatch.setattr(rasa.model_training, "train_async", mock_train)

training_data = """
stories: []
Expand Down Expand Up @@ -733,39 +733,6 @@ def test_training_payload_from_yaml_save_to_default_model_directory(
assert payload.get("output") == expected


@pytest.mark.parametrize(
"headers, expected",
[
({}, {"augmentation_factor": 50}),
({"augmentation": "25"}, {"augmentation_factor": 25}),
],
)
def test_training_payload_from_yaml_core_arguments(
headers: Dict, expected: bool, tmp_path: Path
):
request = Mock()
request.body = b""
request.args = headers

payload = rasa.server._training_payload_from_yaml(request, tmp_path)
assert payload.get("core_additional_arguments") == expected


@pytest.mark.parametrize(
"headers, expected",
[({}, {"num_threads": 1}), ({"num_threads": "2"}, {"num_threads": 2})],
)
def test_training_payload_from_yaml_nlu_arguments(
headers: Dict, expected: bool, tmp_path: Path
):
request = Mock()
request.body = b""
request.args = headers

payload = rasa.server._training_payload_from_yaml(request, tmp_path)
assert payload.get("nlu_additional_arguments") == expected


@pytest.mark.trains_model
async def test_train_missing_config(rasa_app: SanicASGITestClient):
payload = dict(domain="domain data", config=None)
Expand Down

0 comments on commit 24a0db3

Please sign in to comment.