Skip to content

Commit

Permalink
add missing type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
ancalita committed Apr 8, 2021
1 parent 5f1256e commit dd16809
Show file tree
Hide file tree
Showing 19 changed files with 43 additions and 37 deletions.
6 changes: 4 additions & 2 deletions rasa/core/brokers/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import json
import logging
from asyncio import AbstractEventLoop
from typing import Any, Dict, Optional, Text
from typing import Any, Dict, Optional, Text, Generator

from sqlalchemy.orm import Session

from rasa.core.brokers.broker import EventBroker
from rasa.utils.endpoints import EndpointConfig
Expand Down Expand Up @@ -63,7 +65,7 @@ async def from_endpoint_config(
return cls(host=broker_config.url, **broker_config.kwargs)

@contextlib.contextmanager
def session_scope(self):
def session_scope(self) -> Generator[Session, None, None]:
"""Provide a transactional scope around a series of operations."""
session = self.sessionmaker()
try:
Expand Down
4 changes: 2 additions & 2 deletions rasa/core/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Callable,
Iterable,
Awaitable,
NoReturn,
NoReturn, Coroutine,
)

from rasa.cli import utils as cli_utils
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
def register(
input_channels: List["InputChannel"], app: Sanic, route: Optional[Text]
) -> None:
async def handler(*args, **kwargs):
async def handler(*args: Any, **kwargs: Any) -> None:
await app.agent.handle_message(*args, **kwargs)

for channel in input_channels:
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/policies/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(
self._rule_only_data = {}

@property
def featurizer(self):
def featurizer(self) -> TrackerFeaturizer:
"""Returns the policy's featurizer."""
return self.__featurizer

Expand Down
2 changes: 1 addition & 1 deletion rasa/core/policies/sklearn_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _default_model() -> Any:
return LogisticRegression(solver="liblinear", multi_class="auto")

@property
def _state(self):
def _state(self) -> Dict[Text, Any]:
return {attr: getattr(self, attr) for attr in self._pickle_params}

def model_architecture(self, **kwargs: Any) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def configure_app(
rasa.core.utils.list_routes(app)

# configure async loop logging
async def configure_async_logging():
async def configure_async_logging() -> None:
if logger.isEnabledFor(logging.DEBUG):
rasa.utils.io.enable_async_loop_debugging(asyncio.get_event_loop())

Expand Down
11 changes: 6 additions & 5 deletions rasa/core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
Optional,
Text,
Union,
TYPE_CHECKING,
TYPE_CHECKING, Generator,
)

from boto3.dynamodb.conditions import Key
from pymongo.collection import Collection

import rasa.core.utils as core_utils
import rasa.shared.utils.cli
Expand Down Expand Up @@ -364,7 +365,7 @@ def _set_key_prefix(self, key_prefix: Text) -> None:
def _get_key_prefix(self) -> Text:
return self.key_prefix

def save(self, tracker, timeout=None):
def save(self, tracker: DialogueStateTracker, timeout: Any = None) -> None:
"""Saves the current conversation state"""
if self.event_broker:
self.stream_events(tracker)
Expand Down Expand Up @@ -466,7 +467,7 @@ def get_or_create_table(

return table

def save(self, tracker):
def save(self, tracker: DialogueStateTracker) -> None:
"""Saves the current conversation state."""
from botocore.exceptions import ClientError

Expand Down Expand Up @@ -577,7 +578,7 @@ def __init__(
self._ensure_indices()

@property
def conversations(self):
def conversations(self) -> Collection:
"""Returns the current conversation."""
return self.db[self.collection]

Expand Down Expand Up @@ -991,7 +992,7 @@ def _create_database(engine: "Engine", database_name: Text) -> None:
conn.close()

@contextlib.contextmanager
def session_scope(self):
def session_scope(self) -> Generator[Session, None, None]:
"""Provide a transactional scope around a series of operations."""
session = self.sessionmaker()
try:
Expand Down
10 changes: 5 additions & 5 deletions rasa/core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import uuid
from functools import partial
from multiprocessing import Process
from typing import Any, Callable, Deque, Dict, List, Optional, Text, Tuple, Union, Set

from typing import Any, Callable, Deque, Dict, List, Optional, Text, Tuple, Union, Set, Coroutine

from sanic import Sanic, response
from sanic.exceptions import NotFound
from sanic.response import HTTPResponse
from terminaltables import AsciiTable, SingleTable
import numpy as np
from aiohttp import ClientError
Expand Down Expand Up @@ -1603,17 +1603,17 @@ def start_visualization(image_path: Text, port: int) -> None:

# noinspection PyUnusedLocal
@app.exception(NotFound)
async def ignore_404s(request, exception):
async def ignore_404s(request: Any, exception: Any) -> HTTPResponse:
return response.text("Not found", status=404)

# noinspection PyUnusedLocal
@app.route(VISUALIZATION_TEMPLATE_PATH, methods=["GET"])
def visualisation_html(request):
def visualisation_html(request: Any) -> Coroutine[Any, Any, HTTPResponse]:
return response.file(visualization.visualization_html_path())

# noinspection PyUnusedLocal
@app.route("/visualization.dot", methods=["GET"])
def visualisation_png(request):
def visualisation_png(request: Any) -> Union[Coroutine[Any, Any, HTTPResponse], HTTPResponse]:
try:
headers = {"Cache-Control": "no-cache"}
return response.file(os.path.abspath(image_path), headers=headers)
Expand Down
2 changes: 1 addition & 1 deletion rasa/nlu/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ class ComponentMetaclass(type):
"""Metaclass with `name` class property."""

@property
def name(cls):
def name(cls) -> Text:
"""The name property is a function of the class - its __name__."""

return cls.__name__
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ def get(self, property_name: Text, default: Any = None) -> Any:
return self.metadata.get(property_name, default)

@property
def component_classes(self):
def component_classes(self) -> List[Optional[Any]]:
if self.get("pipeline"):
return [c.get("class") for c in self.get("pipeline", [])]
else:
return []

@property
def number_of_components(self):
def number_of_components(self) -> int:
return len(self.get("pipeline", []))

def for_component(self, index: int, defaults: Any = None) -> Dict[Text, Any]:
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def set(self, prop: Text, info: Any) -> None:
def get(self, prop: Text, default: Optional[Any] = None) -> Any:
return self.data.get(prop, default)

def __eq__(self, other):
def __eq__(self, other: Any) -> Any:
if not isinstance(other, Token):
return NotImplemented
return (self.start, self.end, self.text, self.lemma) == (
Expand All @@ -50,7 +50,7 @@ def __eq__(self, other):
other.lemma,
)

def __lt__(self, other):
def __lt__(self, other: Any) -> Any:
if not isinstance(other, Token):
return NotImplemented
return (self.start, self.end, self.text, self.lemma) < (
Expand Down
2 changes: 1 addition & 1 deletion rasa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def ensure_loaded_agent(
`True`.
"""

def decorator(f):
def decorator(f: Callable) -> Callable:
@wraps(f)
def decorated(*args: Any, **kwargs: Any) -> Any:
# noinspection PyUnresolvedReferences
Expand Down
6 changes: 3 additions & 3 deletions rasa/shared/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ def fingerprint(self) -> Text:
return rasa.shared.utils.io.get_dictionary_fingerprint(self_as_dict)

@rasa.shared.utils.common.lazy_property
def user_actions_and_forms(self):
def user_actions_and_forms(self) -> List[Text]:
"""Returns combination of user actions and forms."""

return self.user_actions + self.form_names
Expand All @@ -735,7 +735,7 @@ def num_actions(self) -> int:
return len(self.action_names_or_texts)

@rasa.shared.utils.common.lazy_property
def num_states(self):
def num_states(self) -> int:
"""Number of used input states for the action prediction."""

return len(self.input_states)
Expand Down Expand Up @@ -1537,7 +1537,7 @@ def intent_config(self, intent_name: Text) -> Dict[Text, Any]:
return self.intent_properties.get(intent_name, {})

@rasa.shared.utils.common.lazy_property
def intents(self):
def intents(self) -> Text:
return sorted(self.intent_properties.keys())

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import re
from pathlib import Path
from re import Match
from typing import Dict, Text, List, Any, Union, Tuple, Optional

import rasa.shared.data
Expand Down Expand Up @@ -160,7 +161,7 @@ def _parameters_from_json_string(s: Text, line: Text) -> Dict[Text, Any]:
)

def _replace_template_variables(self, line: Text) -> Text:
def process_match(matchobject):
def process_match(matchobject: Match) -> Any:
varname = matchobject.group(1)
if varname in self.template_variables:
return self.template_variables[varname]
Expand Down
2 changes: 1 addition & 1 deletion rasa/shared/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def lazy_property(function: Callable) -> Any:
attr_name = "_lazy_" + function.__name__

@property
def _lazyprop(self):
def _lazyprop(self: Any) -> Any:
if not hasattr(self, attr_name):
setattr(self, attr_name, function(self))
return getattr(self, attr_name)
Expand Down
8 changes: 4 additions & 4 deletions rasa/shared/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import os
from pathlib import Path
import re
from typing import Any, Dict, List, Optional, Text, Type, Union, FrozenSet
from typing import Any, Dict, List, Optional, Text, Type, Union, FrozenSet, AnyStr
import warnings

from ruamel import yaml as yaml
from ruamel.yaml import RoundTripRepresenter, YAMLError
from ruamel.yaml.constructor import DuplicateKeyError
from ruamel.yaml.constructor import DuplicateKeyError, BaseConstructor

from rasa.shared.constants import (
DEFAULT_LOG_LEVEL,
Expand Down Expand Up @@ -290,7 +290,7 @@ def json_to_string(obj: Any, **kwargs: Any) -> Text:
def fix_yaml_loader() -> None:
"""Ensure that any string read by yaml is represented as unicode."""

def construct_yaml_str(self, node):
def construct_yaml_str(self: Any, node: Any) -> Any:
# Override the default string handling function
# to always return unicode objects
return self.construct_scalar(node)
Expand All @@ -306,7 +306,7 @@ def replace_environment_variables() -> None:
env_var_pattern = re.compile(r"^(.*)\$\{(.*)\}(.*)$")
yaml.Resolver.add_implicit_resolver("!env_var", env_var_pattern, None)

def env_var_constructor(loader, node):
def env_var_constructor(loader: BaseConstructor, node: Any) -> AnyStr:
"""Process environment variables found in the YAML."""
value = loader.construct_scalar(node)
expanded_vars = os.path.expandvars(value)
Expand Down
4 changes: 2 additions & 2 deletions rasa/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,15 +220,15 @@ def ensure_telemetry_enabled(f: Callable[..., Any]) -> Callable[..., Any]:
if asyncio.iscoroutinefunction(f):

@wraps(f)
async def decorated_coroutine(*args, **kwargs):
async def decorated_coroutine(*args: Any, **kwargs: Any) -> Any:
if is_telemetry_enabled():
return await f(*args, **kwargs)
return None

return decorated_coroutine

@wraps(f)
def decorated(*args, **kwargs):
def decorated(*args: Any, **kwargs: Any) -> Any:
if is_telemetry_enabled():
return f(*args, **kwargs)
return None
Expand Down
2 changes: 1 addition & 1 deletion rasa/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class RepeatedLogFilter(logging.Filter):

last_log = None

def filter(self, record):
def filter(self, record: logging.LogRecord) -> bool:
current_log = (
record.levelno,
record.pathname,
Expand Down
5 changes: 3 additions & 2 deletions rasa/utils/tensorflow/crf.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import tensorflow as tf
from tensorflow import TensorShape

from tensorflow_addons.utils.types import TensorLike
from typeguard import typechecked
from typing import Tuple, Any
from typing import Tuple, Any, List, Union


# original code taken from
Expand Down Expand Up @@ -35,7 +36,7 @@ def state_size(self) -> int:
def output_size(self) -> int:
return self._num_tags

def build(self, input_shape):
def build(self, input_shape: Union[TensorShape, List[TensorShape]]) -> None:
super().build(input_shape)

def call(
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ warn_redundant_casts = True
warn_unused_ignores = True
disallow_untyped_calls = True
disallow_incomplete_defs = True
disallow_untyped_defs = True
# FIXME: working our way towards removing these
# see https://github.com/RasaHQ/rasa/pull/6470
# the list below is sorted by the number of errors for each error code, in decreasing order
Expand Down

0 comments on commit dd16809

Please sign in to comment.