Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable union-attr mypy check and fix issues #10942

Merged
merged 23 commits into from
Mar 3, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
more fixes
  • Loading branch information
ancalita committed Feb 24, 2022
commit 5500dce0efea0d0cd2efad2efd114a93892df9ea
4 changes: 3 additions & 1 deletion rasa/core/brokers/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def _publish(self, event: Dict[Text, Any]) -> None:
)

if self.producer is not None:
self.producer.send(self.topic, value=event, key=partition_key, headers=headers)
self.producer.send(
self.topic, value=event, key=partition_key, headers=headers
)

def _close(self) -> None:
if self.producer is not None:
Expand Down
13 changes: 0 additions & 13 deletions rasa/core/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ def as_dict(self) -> Dict[Text, Any]:

def dumps(self) -> Text:
"""Return json dump of `Ticket` as dictionary."""

return json.dumps(self.as_dict())

@classmethod
def from_dict(cls, data: Dict[Text, Union[int, float]]) -> "Ticket":
"""Creates `Ticket` from dictionary."""

return cls(number=data["number"], expires=data["expires"])

def __repr__(self) -> Text:
Expand All @@ -53,13 +51,11 @@ def __init__(
@classmethod
def from_dict(cls, data: Dict[Text, Any]) -> "TicketLock":
"""Create `TicketLock` from dictionary."""

tickets = [Ticket.from_dict(json.loads(d)) for d in data.get("tickets", [])]
return cls(data.get("conversation_id"), deque(tickets))

def dumps(self) -> Text:
"""Return json dump of `TicketLock`."""

tickets = [ticket.dumps() for ticket in self.tickets]
return json.dumps(dict(conversation_id=self.conversation_id, tickets=tickets))

Expand All @@ -69,12 +65,10 @@ def is_locked(self, ticket_number: int) -> bool:
Returns:
True if `now_serving` is not equal to `ticket`.
"""

return self.now_serving != ticket_number

def issue_ticket(self, lifetime: float) -> int:
"""Issue a new ticket and return its number."""

self.remove_expired_tickets()
number = self.last_issued + 1
ticket = Ticket(number, time.time() + lifetime)
Expand All @@ -84,7 +78,6 @@ def issue_ticket(self, lifetime: float) -> int:

def remove_expired_tickets(self) -> None:
"""Remove expired tickets."""

# iterate over copy of self.tickets so we can remove items
for ticket in list(self.tickets):
if ticket.has_expired():
Expand All @@ -98,7 +91,6 @@ def last_issued(self) -> int:
Number of `Ticket` that was last added. `NO_TICKET_ISSUED` if no
tickets exist.
"""

ticket_number = self._ticket_number_for(-1)

return ticket_number if ticket_number is not None else NO_TICKET_ISSUED
Expand All @@ -110,7 +102,6 @@ def now_serving(self) -> Optional[int]:
Returns:
Number of `Ticket` that is served next. 0 if no `Ticket` exists.
"""

return self._ticket_number_for(0) or 0

def _ticket_number_for(self, ticket_index: int) -> Optional[int]:
Expand All @@ -120,7 +111,6 @@ def _ticket_number_for(self, ticket_index: int) -> Optional[int]:
Ticket number for `Ticket` with index `ticket_index`. None if there are no
tickets, or if `ticket_index` is out of bounds of `self.tickets`.
"""

self.remove_expired_tickets()

try:
Expand All @@ -130,7 +120,6 @@ def _ticket_number_for(self, ticket_index: int) -> Optional[int]:

def _ticket_for_ticket_number(self, ticket_number: int) -> Optional[Ticket]:
"""Return ticket for `ticket_number`."""

self.remove_expired_tickets()

return next((t for t in self.tickets if t.number == ticket_number), None)
Expand All @@ -141,12 +130,10 @@ def is_someone_waiting(self) -> bool:
Returns:
True if the `self.tickets` queue has length greater than 0.
"""

return len(self.tickets) > 0

def remove_ticket_for(self, ticket_number: int) -> None:
"""Remove `Ticket` for `ticket_number."""

ticket = self._ticket_for_ticket_number(ticket_number)
if ticket:
self.tickets.remove(ticket)
13 changes: 9 additions & 4 deletions rasa/shared/core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def _append_current_state(self) -> None:
if self._states_for_hashing is None:
self._states_for_hashing = self.past_states_for_hashing(self.domain)
else:
if self.domain is None:
ancalita marked this conversation as resolved.
Show resolved Hide resolved
return
state = self.domain.get_active_state(self)
frozen_state = self.freeze_current_state(state)
self._states_for_hashing.append(frozen_state)
Expand All @@ -194,10 +196,13 @@ def update(self, event: Event, skip_states: bool = False) -> None:
# if `skip_states` is `True`, this function behaves exactly like the
# normal update of the `DialogueStateTracker`

if self._states_for_hashing is None and not skip_states:
# rest of this function assumes we have the previous state
# cached. let's make sure it is there.
self._states_for_hashing = self.past_states_for_hashing(self.domain)
if self._states_for_hashing is None:
if not skip_states:
ancalita marked this conversation as resolved.
Show resolved Hide resolved
# rest of this function assumes we have the previous state
# cached. let's make sure it is there.
self._states_for_hashing = self.past_states_for_hashing(self.domain)
else:
self._states_for_hashing = deque()
ancalita marked this conversation as resolved.
Show resolved Hide resolved

super().update(event)

Expand Down
22 changes: 15 additions & 7 deletions rasa/utils/tensorflow/temp_keras_modules.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import List, Dict, Union, Optional, Any, Generator, Tuple, Iterator
from typing import List, Dict, Union, Optional, Any, Generator, Tuple, Iterator, cast

import numpy as np

Expand Down Expand Up @@ -362,7 +362,7 @@ def fit(
self.stop_training = False
self.train_function = self.make_train_function()
self._train_counter.assign(0)
callbacks.on_train_begin()
cast(callbacks_module.CallbackList, callbacks).on_train_begin()
ancalita marked this conversation as resolved.
Show resolved Hide resolved
training_logs = None
# Handle fault-tolerance for multi-worker.
# TODO(omalleyt): Fix the ordering issues that mean this has to
Expand All @@ -373,7 +373,7 @@ def fit(
logs = None
for epoch, iterator in data_handler.enumerate_epochs():
self.reset_metrics()
callbacks.on_epoch_begin(epoch)
cast(callbacks_module.CallbackList, callbacks).on_epoch_begin(epoch)
with data_handler.catch_stop_iteration():
for step in data_handler.steps():
with tf.profiler.experimental.Trace(
Expand All @@ -383,13 +383,17 @@ def fit(
batch_size=batch_size,
_r=1,
):
callbacks.on_train_batch_begin(step)
cast(
callbacks_module.CallbackList, callbacks
).on_train_batch_begin(step)
tmp_logs = self.train_function(iterator)
if data_handler.should_sync:
context.async_wait()
logs = tmp_logs # No error, now safe to assign to logs.
end_step = step + data_handler.step_increment
callbacks.on_train_batch_end(end_step, logs)
cast(
callbacks_module.CallbackList, callbacks
).on_train_batch_end(end_step, logs)
if self.stop_training:
break

Expand Down Expand Up @@ -439,15 +443,19 @@ def fit(
val_logs = {"val_" + name: val for name, val in val_logs.items()}
epoch_logs.update(val_logs)

callbacks.on_epoch_end(epoch, epoch_logs)
cast(callbacks_module.CallbackList, callbacks).on_epoch_end(
epoch, epoch_logs
)
training_logs = epoch_logs
if self.stop_training:
break

# If eval_data_handler exists, delete it after all epochs are done.
if getattr(self, "_eval_data_handler", None) is not None:
del self._eval_data_handler
callbacks.on_train_end(logs=training_logs)
cast(callbacks_module.CallbackList, callbacks).on_train_end(
logs=training_logs
)
return self.history


Expand Down
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,4 @@ disallow_untyped_decorators = True
# FIXME: working our way towards removing these
# see https://github.com/RasaHQ/rasa/pull/6470
# the list below is sorted by the number of errors for each error code, in decreasing order
disable_error_code = arg-type, assignment, var-annotated,
override, misc
disable_error_code = arg-type, assignment, var-annotated, override, misc