@@ -97,12 +97,12 @@ class ErrorResponse(Exception):
97
97
"""Common exception to handle failing API requests."""
98
98
99
99
def __init__ (
100
- self ,
101
- status : Union [int , HTTPStatus ],
102
- reason : Text ,
103
- message : Text ,
104
- details : Any = None ,
105
- help_url : Optional [Text ] = None ,
100
+ self ,
101
+ status : Union [int , HTTPStatus ],
102
+ reason : Text ,
103
+ message : Text ,
104
+ details : Any = None ,
105
+ help_url : Optional [Text ] = None ,
106
106
) -> None :
107
107
"""Creates error.
108
108
@@ -133,7 +133,7 @@ def _docs(sub_url: Text) -> Text:
133
133
134
134
135
135
def ensure_loaded_agent (
136
- app : Sanic , require_core_is_ready : bool = False
136
+ app : Sanic , require_core_is_ready : bool = False
137
137
) -> Callable [[Callable ], Callable [..., Any ]]:
138
138
"""Wraps a request handler ensuring there is a loaded and usable agent.
139
139
@@ -181,7 +181,7 @@ def decorated(request: Request, *args: Any, **kwargs: Any) -> "SanicResponse":
181
181
182
182
183
183
def requires_auth (
184
- app : Sanic , token : Optional [Text ] = None
184
+ app : Sanic , token : Optional [Text ] = None
185
185
) -> Callable [["SanicView" ], "SanicView" ]:
186
186
"""Wraps a request handler with token authentication."""
187
187
@@ -200,7 +200,7 @@ def conversation_id_from_args(args: Any, kwargs: Any) -> Optional[Text]:
200
200
return None
201
201
202
202
async def sufficient_scope (
203
- request : Request , * args : Any , ** kwargs : Any
203
+ request : Request , * args : Any , ** kwargs : Any
204
204
) -> Optional [bool ]:
205
205
# This is a coroutine since `sanic-jwt==1.6`
206
206
jwt_data = await rasa .utils .common .call_potential_coroutine (
@@ -222,7 +222,7 @@ async def sufficient_scope(
222
222
223
223
@wraps (f )
224
224
async def decorated (
225
- request : Request , * args : Any , ** kwargs : Any
225
+ request : Request , * args : Any , ** kwargs : Any
226
226
) -> response .HTTPResponse :
227
227
228
228
provided = request .args .get ("token" , None )
@@ -232,7 +232,7 @@ async def decorated(
232
232
result = f (request , * args , ** kwargs )
233
233
return await result if isawaitable (result ) else result
234
234
elif app .config .get (
235
- "USE_JWT"
235
+ "USE_JWT"
236
236
) and await rasa .utils .common .call_potential_coroutine (
237
237
# This is a coroutine since `sanic-jwt==1.6`
238
238
request .app .ctx .auth .is_authenticated (request )
@@ -267,7 +267,7 @@ async def decorated(
267
267
268
268
269
269
def event_verbosity_parameter (
270
- request : Request , default_verbosity : EventVerbosity
270
+ request : Request , default_verbosity : EventVerbosity
271
271
) -> EventVerbosity :
272
272
"""Create `EventVerbosity` object using request params if present."""
273
273
event_verbosity_str = request .args .get (
@@ -287,10 +287,10 @@ def event_verbosity_parameter(
287
287
288
288
289
289
def get_test_stories (
290
- processor : "MessageProcessor" ,
291
- conversation_id : Text ,
292
- until_time : Optional [float ],
293
- fetch_all_sessions : bool = False ,
290
+ processor : "MessageProcessor" ,
291
+ conversation_id : Text ,
292
+ until_time : Optional [float ],
293
+ fetch_all_sessions : bool = False ,
294
294
) -> Text :
295
295
"""Retrieves test stories from `processor` for all conversation sessions for
296
296
`conversation_id`.
@@ -336,10 +336,10 @@ def get_test_stories(
336
336
337
337
338
338
async def update_conversation_with_events (
339
- conversation_id : Text ,
340
- processor : "MessageProcessor" ,
341
- domain : Domain ,
342
- events : List [Event ],
339
+ conversation_id : Text ,
340
+ processor : "MessageProcessor" ,
341
+ domain : Domain ,
342
+ events : List [Event ],
343
343
) -> DialogueStateTracker :
344
344
"""Fetches or creates a tracker for `conversation_id` and appends `events` to it.
345
345
@@ -398,10 +398,10 @@ async def authenticate(_: Request) -> NoReturn:
398
398
399
399
400
400
def create_ssl_context (
401
- ssl_certificate : Optional [Text ],
402
- ssl_keyfile : Optional [Text ],
403
- ssl_ca_file : Optional [Text ] = None ,
404
- ssl_password : Optional [Text ] = None ,
401
+ ssl_certificate : Optional [Text ],
402
+ ssl_keyfile : Optional [Text ],
403
+ ssl_ca_file : Optional [Text ] = None ,
404
+ ssl_password : Optional [Text ] = None ,
405
405
) -> Optional ["SSLContext" ]:
406
406
"""Create an SSL context if a proper certificate is passed.
407
407
@@ -460,10 +460,10 @@ def _create_emulator(mode: Optional[Text]) -> Emulator:
460
460
461
461
462
462
async def _load_agent (
463
- model_path : Optional [Text ] = None ,
464
- model_server : Optional [EndpointConfig ] = None ,
465
- remote_storage : Optional [Text ] = None ,
466
- endpoints : Optional [AvailableEndpoints ] = None ,
463
+ model_path : Optional [Text ] = None ,
464
+ model_server : Optional [EndpointConfig ] = None ,
465
+ remote_storage : Optional [Text ] = None ,
466
+ endpoints : Optional [AvailableEndpoints ] = None ,
467
467
) -> Agent :
468
468
try :
469
469
loaded_agent = await rasa .core .agent .load_agent (
@@ -492,7 +492,7 @@ async def _load_agent(
492
492
493
493
494
494
def configure_cors (
495
- app : Sanic , cors_origins : Union [Text , List [Text ], None ] = ""
495
+ app : Sanic , cors_origins : Union [Text , List [Text ], None ] = ""
496
496
) -> None :
497
497
"""Configure CORS origins for the given app."""
498
498
@@ -533,7 +533,7 @@ def async_if_callback_url(f: Callable[..., Coroutine]) -> Callable:
533
533
534
534
@wraps (f )
535
535
async def decorated_function (
536
- request : Request , * args : Any , ** kwargs : Any
536
+ request : Request , * args : Any , ** kwargs : Any
537
537
) -> HTTPResponse :
538
538
callback_url = request .args .get ("callback_url" )
539
539
# Only process request asynchronously if the user specified a `callback_url`
@@ -595,7 +595,7 @@ def run_in_thread(f: Callable[..., Coroutine]) -> Callable:
595
595
596
596
@wraps (f )
597
597
async def decorated_function (
598
- request : Request , * args : Any , ** kwargs : Any
598
+ request : Request , * args : Any , ** kwargs : Any
599
599
) -> HTTPResponse :
600
600
# Use a sync wrapper for our `async` function as `run_in_executor` only supports
601
601
# sync functions
@@ -628,13 +628,13 @@ async def decorated_function(*args: Any, **kwargs: Any) -> HTTPResponse:
628
628
629
629
630
630
def create_app (
631
- agent : Optional ["Agent" ] = None ,
632
- cors_origins : Union [Text , List [Text ], None ] = "*" ,
633
- auth_token : Optional [Text ] = None ,
634
- response_timeout : int = DEFAULT_RESPONSE_TIMEOUT ,
635
- jwt_secret : Optional [Text ] = None ,
636
- jwt_method : Text = "HS256" ,
637
- endpoints : Optional [AvailableEndpoints ] = None ,
631
+ agent : Optional ["Agent" ] = None ,
632
+ cors_origins : Union [Text , List [Text ], None ] = "*" ,
633
+ auth_token : Optional [Text ] = None ,
634
+ response_timeout : int = DEFAULT_RESPONSE_TIMEOUT ,
635
+ jwt_secret : Optional [Text ] = None ,
636
+ jwt_method : Text = "HS256" ,
637
+ endpoints : Optional [AvailableEndpoints ] = None ,
638
638
) -> Sanic :
639
639
"""Class representing a Rasa HTTP server."""
640
640
app = Sanic (__name__ )
@@ -662,7 +662,7 @@ def create_app(
662
662
663
663
@app .exception (ErrorResponse )
664
664
async def handle_error_response (
665
- request : Request , exception : ErrorResponse
665
+ request : Request , exception : ErrorResponse
666
666
) -> HTTPResponse :
667
667
return response .json (exception .error_info , status = exception .status )
668
668
@@ -740,7 +740,7 @@ async def append_events(request: Request, conversation_id: Text) -> HTTPResponse
740
740
output_channel = _get_output_channel (request , tracker )
741
741
742
742
if rasa .utils .endpoints .bool_arg (
743
- request , EXECUTE_SIDE_EFFECTS_QUERY_KEY , False
743
+ request , EXECUTE_SIDE_EFFECTS_QUERY_KEY , False
744
744
):
745
745
await processor .execute_side_effects (
746
746
events , tracker , output_channel
@@ -1071,7 +1071,7 @@ async def train(request: Request, temporary_directory: Path) -> HTTPResponse:
1071
1071
@ensure_loaded_agent (app , require_core_is_ready = True )
1072
1072
@inject_temp_dir
1073
1073
async def evaluate_stories (
1074
- request : Request , temporary_directory : Path
1074
+ request : Request , temporary_directory : Path
1075
1075
) -> HTTPResponse :
1076
1076
"""Evaluate stories against the currently loaded model."""
1077
1077
validate_request_body (
@@ -1103,7 +1103,7 @@ async def evaluate_stories(
1103
1103
@run_in_thread
1104
1104
@inject_temp_dir
1105
1105
async def evaluate_intents (
1106
- request : Request , temporary_directory : Path
1106
+ request : Request , temporary_directory : Path
1107
1107
) -> HTTPResponse :
1108
1108
"""Evaluate intents against a Rasa model."""
1109
1109
validate_request_body (
@@ -1151,7 +1151,7 @@ async def evaluate_intents(
1151
1151
)
1152
1152
1153
1153
async def _evaluate_model_using_test_set (
1154
- model_path : Text , test_data_file : Text
1154
+ model_path : Text , test_data_file : Text
1155
1155
) -> Dict :
1156
1156
logger .info ("Starting model evaluation using test set." )
1157
1157
@@ -1203,9 +1203,9 @@ async def _cross_validate(data_file: Text, config_file: Text, folds: int) -> Dic
1203
1203
return evaluation_results
1204
1204
1205
1205
def _get_evaluation_results (
1206
- intent_report : CVEvaluationResult ,
1207
- entity_report : CVEvaluationResult ,
1208
- response_selector_report : CVEvaluationResult ,
1206
+ intent_report : CVEvaluationResult ,
1207
+ entity_report : CVEvaluationResult ,
1208
+ response_selector_report : CVEvaluationResult ,
1209
1209
) -> Dict [Text , Any ]:
1210
1210
eval_name_mapping = {
1211
1211
"intent_evaluation" : intent_report ,
@@ -1366,7 +1366,7 @@ async def get_domain(request: Request) -> HTTPResponse:
1366
1366
1367
1367
1368
1368
def _get_output_channel (
1369
- request : Request , tracker : Optional [DialogueStateTracker ]
1369
+ request : Request , tracker : Optional [DialogueStateTracker ]
1370
1370
) -> OutputChannel :
1371
1371
"""Returns the `OutputChannel` which should be used for the bot's responses.
1372
1372
@@ -1381,8 +1381,8 @@ def _get_output_channel(
1381
1381
requested_output_channel = request .args .get (OUTPUT_CHANNEL_QUERY_KEY )
1382
1382
1383
1383
if (
1384
- requested_output_channel == USE_LATEST_INPUT_CHANNEL_AS_OUTPUT_CHANNEL
1385
- and tracker
1384
+ requested_output_channel == USE_LATEST_INPUT_CHANNEL_AS_OUTPUT_CHANNEL
1385
+ and tracker
1386
1386
):
1387
1387
requested_output_channel = tracker .get_latest_input_channel ()
1388
1388
@@ -1398,7 +1398,7 @@ def _get_output_channel(
1398
1398
# otherwise use `CollectingOutputChannel`
1399
1399
return reduce (
1400
1400
lambda output_channel_created_so_far , input_channel : (
1401
- input_channel .get_output_channel () or output_channel_created_so_far
1401
+ input_channel .get_output_channel () or output_channel_created_so_far
1402
1402
),
1403
1403
matching_channels ,
1404
1404
CollectingOutputChannel (),
@@ -1445,7 +1445,7 @@ def _validate_json_training_payload(rjs: Dict) -> None:
1445
1445
1446
1446
1447
1447
def _training_payload_from_yaml (
1448
- request : Request , temp_dir : Path , file_name : Text = "data.yml"
1448
+ request : Request , temp_dir : Path , file_name : Text = "data.yml"
1449
1449
) -> Dict [Text , Any ]:
1450
1450
logger .debug ("Extracting YAML training data from request body." )
1451
1451
0 commit comments