Skip to content

Commit

Permalink
Port changes from "use pickler alternatives" to 3.6.x
Browse files Browse the repository at this point in the history
  • Loading branch information
tabergma committed Nov 26, 2024
1 parent e0f4e15 commit 5197074
Show file tree
Hide file tree
Showing 24 changed files with 1,606 additions and 476 deletions.
19 changes: 19 additions & 0 deletions changelog/1424.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
Replace `pickle` and `joblib` with safer alternatives, e.g. `json`, `safetensors`, and `skops`, for
serializing components.

**Note**: This is a model-breaking change. Please retrain your model.

If you have a custom component that inherits from one of the components listed below and overrides the `persist` or
`load` method, make sure to update your code accordingly. Please contact us if you encounter any problems.

Affected components:

- `CountVectorFeaturizer`
- `LexicalSyntacticFeaturizer`
- `LogisticRegressionClassifier`
- `SklearnIntentClassifier`
- `DIETClassifier`
- `CRFEntityExtractor`
- `TrackerFeaturizer`
- `TEDPolicy`
- `UnexpectedIntentTEDPolicy`
22 changes: 22 additions & 0 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ sanic-cors = "~2.0.0"
sanic-jwt = "^1.6.0"
sanic-routing = "^0.7.2"
websockets = ">=10.0,<11.0"
cloudpickle = ">=1.2,<2.3"
aiohttp = ">=3.6,!=3.7.4.post0,<3.9"
questionary = ">=1.5.1,<1.11.0"
prompt-toolkit = "^3.0,<3.0.29"
Expand All @@ -135,7 +134,6 @@ tensorflow_hub = "^0.13.0"
setuptools = ">=65.5.1"
ujson = ">=1.35,<6.0"
regex = ">=2020.6,<2022.11"
joblib = ">=0.15.1,<1.3.0"
sentry-sdk = ">=0.17.0,<1.15.0"
aio-pika = ">=6.7.1,<8.2.4"
aiogram = "<2.26"
Expand All @@ -155,6 +153,8 @@ dnspython = "2.3.0"
wheel = ">=0.38.1"
certifi = ">=2023.7.22"
cryptography = ">=41.0.2"
skops = "~0.10.0"
safetensors = "~0.4.5"
[[tool.poetry.dependencies.tensorflow-io-gcs-filesystem]]
version = "==0.31"
markers = "sys_platform == 'win32'"
Expand Down
23 changes: 22 additions & 1 deletion rasa/core/featurizers/single_state_featurizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from typing import List, Optional, Dict, Text, Set, Any

import numpy as np
import scipy.sparse
from typing import List, Optional, Dict, Text, Set, Any

from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
from rasa.nlu.extractors.extractor import EntityTagSpec
Expand Down Expand Up @@ -362,6 +363,26 @@ def encode_all_labels(
for action in domain.action_names_or_texts
]

def to_dict(self) -> Dict[str, Any]:
    """Collect the featurizer's state in a plain dictionary.

    Returns:
        A dictionary with the entries "action_texts", "entity_tag_specs" and
        "feature_states"; the last one holds `self._default_feature_states`.
    """
    return dict(
        action_texts=self.action_texts,
        entity_tag_specs=self.entity_tag_specs,
        feature_states=self._default_feature_states,
    )

@classmethod
def create_from_dict(
    cls, data: Dict[str, Any]
) -> Optional["SingleStateFeaturizer"]:
    """Restore a featurizer from the dictionary produced by `to_dict`.

    Args:
        data: Dictionary with the entries "action_texts", "feature_states"
            and "entity_tag_specs". May be empty or `None`.

    Returns:
        A new instance of `cls` populated from `data`, or `None` if `data`
        is empty or `None`.

    Raises:
        KeyError: If `data` is non-empty but lacks one of the expected entries.
    """
    if not data:
        return None

    # Instantiate through `cls` rather than a hard-coded class so that calling
    # this alternate constructor on a subclass yields an instance of it.
    featurizer = cls()
    featurizer.action_texts = data["action_texts"]
    featurizer._default_feature_states = data["feature_states"]
    featurizer.entity_tag_specs = data["entity_tag_specs"]
    return featurizer


class IntentTokenizerSingleStateFeaturizer(SingleStateFeaturizer):
"""A SingleStateFeaturizer for use with policies that predict intent labels."""
Expand Down
133 changes: 115 additions & 18 deletions rasa/core/featurizers/tracker_featurizers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations
from pathlib import Path
from collections import defaultdict
from abc import abstractmethod
import jsonpickle
import logging

from tqdm import tqdm
import logging
from abc import abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import (
Tuple,
List,
Expand All @@ -18,25 +16,30 @@
Set,
DefaultDict,
cast,
Type,
Callable,
ClassVar,
)

import numpy as np
from tqdm import tqdm

from rasa.core.featurizers.single_state_featurizer import SingleStateFeaturizer
from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
from rasa.core.exceptions import InvalidTrackerFeaturizerUsageError
import rasa.shared.core.trackers
import rasa.shared.utils.io
from rasa.shared.nlu.constants import TEXT, INTENT, ENTITIES, ACTION_NAME
from rasa.shared.nlu.training_data.features import Features
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.domain import State, Domain
from rasa.shared.core.events import Event, ActionExecuted, UserUttered
from rasa.core.exceptions import InvalidTrackerFeaturizerUsageError
from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
from rasa.core.featurizers.single_state_featurizer import SingleStateFeaturizer
from rasa.shared.core.constants import (
USER,
ACTION_UNLIKELY_INTENT_NAME,
PREVIOUS_ACTION,
)
from rasa.shared.core.domain import State, Domain
from rasa.shared.core.events import Event, ActionExecuted, UserUttered
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.exceptions import RasaException
from rasa.shared.nlu.constants import TEXT, INTENT, ENTITIES, ACTION_NAME
from rasa.shared.nlu.training_data.features import Features
from rasa.utils.tensorflow.constants import LABEL_PAD_ID
from rasa.utils.tensorflow.model_data import ragged_array_to_ndarray

Expand Down Expand Up @@ -64,6 +67,10 @@ def __str__(self) -> Text:
class TrackerFeaturizer:
"""Base class for actual tracker featurizers."""

# Class registry to store all subclasses
_registry: ClassVar[Dict[str, Type["TrackerFeaturizer"]]] = {}
_featurizer_type: str = "TrackerFeaturizer"

def __init__(
self, state_featurizer: Optional[SingleStateFeaturizer] = None
) -> None:
Expand All @@ -74,6 +81,36 @@ def __init__(
"""
self.state_featurizer = state_featurizer

@classmethod
def register(cls, featurizer_type: str) -> Callable:
    """Build a class decorator that registers a featurizer subclass.

    Args:
        featurizer_type: Key under which the decorated class is stored in
            `cls._registry`.

    Returns:
        A decorator that records the class in the registry, stores
        `featurizer_type` on it as `_featurizer_type` and returns the class
        itself.
    """

    def decorator(subclass: Type["TrackerFeaturizer"]) -> Type["TrackerFeaturizer"]:
        cls._registry.update({featurizer_type: subclass})
        # Remember the identifier on the class so it can be written out on
        # serialization.
        setattr(subclass, "_featurizer_type", featurizer_type)
        return subclass

    return decorator

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TrackerFeaturizer":
    """Create a featurizer instance from its dictionary representation.

    Args:
        data: Dictionary with a "type" entry naming a registered featurizer
            type. The remaining entries are passed to that class's
            `create_from_dict`. The dictionary itself is left unmodified.

    Returns:
        Whatever the registered subclass's `create_from_dict` returns.

    Raises:
        KeyError: If `data` has no "type" entry.
        ValueError: If the type is not present in the registry.
    """
    # Pop from a shallow copy: the subclass still receives the data without
    # the "type" entry, but the caller's dictionary is not mutated.
    payload = dict(data)
    featurizer_type = payload.pop("type")

    if featurizer_type not in cls._registry:
        raise ValueError(f"Unknown featurizer type: {featurizer_type}")

    # Delegate construction to the subclass registered for this type
    subclass = cls._registry[featurizer_type]
    return subclass.create_from_dict(payload)

@classmethod
@abstractmethod
def create_from_dict(cls, data: Dict[str, Any]) -> "TrackerFeaturizer":
    """Build a featurizer from `data`; each subclass supplies its own version.

    This base version has no body beyond this docstring and returns `None`.
    """

@staticmethod
def _create_states(
tracker: DialogueStateTracker,
Expand Down Expand Up @@ -465,9 +502,7 @@ def persist(self, path: Union[Text, Path]) -> None:
self.state_featurizer.entity_tag_specs = []

# noinspection PyTypeChecker
rasa.shared.utils.io.write_text_file(
str(jsonpickle.encode(self)), featurizer_file
)
rasa.shared.utils.io.dump_obj_as_json_to_file(featurizer_file, self.to_dict())

@staticmethod
def load(path: Union[Text, Path]) -> Optional[TrackerFeaturizer]:
Expand All @@ -481,7 +516,17 @@ def load(path: Union[Text, Path]) -> Optional[TrackerFeaturizer]:
"""
featurizer_file = Path(path) / FEATURIZER_FILE
if featurizer_file.is_file():
return jsonpickle.decode(rasa.shared.utils.io.read_file(featurizer_file))
data = rasa.shared.utils.io.read_json_file(featurizer_file)

if "type" not in data:
logger.error(
f"Couldn't load featurizer for policy. "
f"File '{featurizer_file}' does not contain all "
f"necessary information. 'type' is missing."
)
return None

return TrackerFeaturizer.from_dict(data)

logger.error(
f"Couldn't load featurizer for policy. "
Expand All @@ -508,7 +553,16 @@ def _remove_action_unlikely_intent_from_events(events: List[Event]) -> List[Even
)
]

def to_dict(self) -> Dict[str, Any]:
    """Collect the featurizer's state in a plain dictionary.

    Returns:
        A dictionary with the class's `_featurizer_type` under "type" and,
        under "state_featurizer", the result of the state featurizer's
        `to_dict` (or `None` when no state featurizer is set).
    """
    serialized: Dict[str, Any] = {"type": self.__class__._featurizer_type}
    if self.state_featurizer:
        serialized["state_featurizer"] = self.state_featurizer.to_dict()
    else:
        serialized["state_featurizer"] = None
    return serialized


@TrackerFeaturizer.register("FullDialogueTrackerFeaturizer")
class FullDialogueTrackerFeaturizer(TrackerFeaturizer):
"""Creates full dialogue training data for time distributed architectures.
Expand Down Expand Up @@ -646,7 +700,20 @@ def prediction_states(

return trackers_as_states

def to_dict(self) -> Dict[str, Any]:
    """Return the parent class's dictionary representation unchanged."""
    serialized = super().to_dict()
    return serialized

@classmethod
def create_from_dict(cls, data: Dict[str, Any]) -> "FullDialogueTrackerFeaturizer":
    """Build a featurizer from its dictionary representation.

    Args:
        data: Dictionary whose "state_featurizer" entry is handed to
            `SingleStateFeaturizer.create_from_dict`.

    Returns:
        An instance of `cls` constructed with the restored state featurizer.
    """
    restored_state_featurizer = SingleStateFeaturizer.create_from_dict(
        data["state_featurizer"]
    )
    return cls(restored_state_featurizer)


@TrackerFeaturizer.register("MaxHistoryTrackerFeaturizer")
class MaxHistoryTrackerFeaturizer(TrackerFeaturizer):
"""Truncates the tracker history into `max_history` long sequences.
Expand Down Expand Up @@ -887,7 +954,25 @@ def prediction_states(

return trackers_as_states

def to_dict(self) -> Dict[str, Any]:
    """Extend the parent's dictionary representation with this class's settings.

    Returns:
        The dictionary from the parent's `to_dict` with the additional
        entries "remove_duplicates" and "max_history".
    """
    serialized = super().to_dict()
    extra_settings = {
        "remove_duplicates": self.remove_duplicates,
        "max_history": self.max_history,
    }
    for key, value in extra_settings.items():
        serialized[key] = value
    return serialized

@classmethod
def create_from_dict(cls, data: Dict[str, Any]) -> "MaxHistoryTrackerFeaturizer":
    """Build a featurizer from its dictionary representation.

    Args:
        data: Dictionary with the entries "state_featurizer", "max_history"
            and "remove_duplicates".

    Returns:
        An instance of `cls` constructed from the restored state featurizer
        and the stored "max_history" and "remove_duplicates" values, passed
        positionally in that order.
    """
    restored_state_featurizer = SingleStateFeaturizer.create_from_dict(
        data["state_featurizer"]
    )
    max_history = data["max_history"]
    remove_duplicates = data["remove_duplicates"]
    return cls(restored_state_featurizer, max_history, remove_duplicates)


@TrackerFeaturizer.register("IntentMaxHistoryTrackerFeaturizer")
class IntentMaxHistoryTrackerFeaturizer(MaxHistoryTrackerFeaturizer):
"""Truncates the tracker history into `max_history` long sequences.
Expand Down Expand Up @@ -1166,6 +1251,18 @@ def prediction_states(

return trackers_as_states

def to_dict(self) -> Dict[str, Any]:
    """Return the parent class's dictionary representation unchanged."""
    serialized = super().to_dict()
    return serialized

@classmethod
def create_from_dict(
    cls, data: Dict[str, Any]
) -> "IntentMaxHistoryTrackerFeaturizer":
    """Build a featurizer from its dictionary representation.

    Args:
        data: Dictionary with the entries "state_featurizer", "max_history"
            and "remove_duplicates".

    Returns:
        An instance of `cls` constructed from the restored state featurizer
        and the stored "max_history" and "remove_duplicates" values, passed
        positionally in that order.
    """
    restored_state_featurizer = SingleStateFeaturizer.create_from_dict(
        data["state_featurizer"]
    )
    max_history = data["max_history"]
    remove_duplicates = data["remove_duplicates"]
    return cls(restored_state_featurizer, max_history, remove_duplicates)


def _is_prev_action_unlikely_intent_in_state(state: State) -> bool:
prev_action_name = state.get(PREVIOUS_ACTION, {}).get(ACTION_NAME)
Expand Down
Loading

0 comments on commit 5197074

Please sign in to comment.