-
Notifications
You must be signed in to change notification settings - Fork 4.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add core and nlu additional params to training endpoint #8132
Conversation
@wochinge please review this, I'll update the changelog later |
Thanks for submitting a pull request 🚀 @amn41 will take a look at it as soon as possible ✨ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this clean and concise change! 🚀
Can you please also add a changelog entry?
rasa/server.py
Outdated
@@ -69,7 +69,7 @@ | |||
from rasa.core.utils import AvailableEndpoints | |||
from rasa.nlu.emulators.no_emulator import NoEmulator | |||
from rasa.nlu.test import run_evaluation, CVEvaluationResult | |||
from rasa.utils.endpoints import EndpointConfig | |||
from rasa.utils.endpoints import EndpointConfig, bool_arg, float_arg, int_arg |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we have some internal code convention where we say that don't want to import functions directly in favor of explictly referring to it by module. So it would be great if you could also refer to bool_arg
etc. using rasa.utils.endpoints.bool_arg
🙌🏻
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
rasa/utils/endpoints.py
Outdated
def int_arg( | ||
request: Request, key: Text, default: Optional[int] = None | ||
) -> Optional[int]: | ||
"""Return a passed argument cast as an int or None. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please also describe the other paramaters and then return values using Args:
and Returns
:?
"""Return a passed argument cast as an int or None. | |
"""Returns a passed argument cast as an int or None. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added for bool_arg
, float_arg
and int_arg
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome. Minor comment: Could you please change Return
to Returns
? We try to use descriptive docstrings in the codebase.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I forgot to update this one
* Use argument casting functions from endpoints module * Update docstring * Add tests for argument casting functions * Add changelog
tests/utils/test_endpoints.py
Outdated
("true", False, True), | ||
], | ||
) | ||
def test_bool_arg(value, default, expected_result): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you please add type annotations to the test parameters? (same for the other test functions) 🙌🏻
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure
tests/test_server.py
Outdated
({"augmentation": "25"}, {"augmentation_factor": 25}), | ||
], | ||
) | ||
def test_training_payload_from_yaml_core_arguments( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be more robust to test with an actual http request (e.g. like done in this test function). This has the advantange that the test focuses on the functionality and less on the specific implementation (_training_payload_from_yaml
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you use monkeypatch
to mock the actual training, you can test whether the parameters were passed correctly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
awesome! I think we can even delete this and the ones after this then.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few tiny comments 💯
rasa/utils/endpoints.py
Outdated
def int_arg( | ||
request: Request, key: Text, default: Optional[int] = None | ||
) -> Optional[int]: | ||
"""Return a passed argument cast as an int or None. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome. Minor comment: Could you please change Return
to Returns
? We try to use descriptive docstrings in the codebase.
tests/test_server.py
Outdated
async def test_train_with_yaml_with_params( | ||
monkeypatch: MonkeyPatch, | ||
rasa_app: SanicASGITestClient, | ||
tmpdir: pathlib.Path, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tmpdir
is of type LocalPath
which is slightly different. I suggest using tmp_path
tmpdir: pathlib.Path, | |
tmp_path: pathlib.Path, |
tests/test_server.py
Outdated
future = asyncio.Future() | ||
future.set_result(TrainingResult(model=fake_model_path)) | ||
mock_train = Mock(return_value=future) | ||
monkeypatch.setattr(sys.modules["rasa.train"], "train_async", mock_train) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we recently moved this:
monkeypatch.setattr(sys.modules["rasa.train"], "train_async", mock_train) | |
monkeypatch.setattr(rasa.model_training, "train_async", mock_train) |
tests/test_server.py
Outdated
({"augmentation": "25"}, {"augmentation_factor": 25}), | ||
], | ||
) | ||
def test_training_payload_from_yaml_core_arguments( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
awesome! I think we can even delete this and the ones after this then.
24a0db3
to
7ff917f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work!
Cool, thanks @wochinge, it's my pleasure to have contributions to this project |
Proposed changes:
Status (please check what you already did):
black
(please check Readme for instructions)