From 2c877a4a3400b27e1e698c4736c48d3e0bf7372c Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 31 Aug 2023 20:14:41 -0400 Subject: [PATCH] proper embeddings and rolling window average --- .../langchain/chains/rl_chain/__init__.py | 12 +- .../langchain/chains/rl_chain/base.py | 29 ++-- .../langchain/chains/rl_chain/metrics.py | 65 ++++++-- .../rl_chain/test_pick_best_chain_call.py | 22 ++- .../rl_chain/test_pick_best_text_embedder.py | 58 ++++---- .../rl_chain/test_rl_chain_base_embedder.py | 140 +++++++++--------- 6 files changed, 184 insertions(+), 142 deletions(-) diff --git a/libs/langchain/langchain/chains/rl_chain/__init__.py b/libs/langchain/langchain/chains/rl_chain/__init__.py index 6d5cfc3e29c78..3a14861bd7f10 100644 --- a/libs/langchain/langchain/chains/rl_chain/__init__.py +++ b/libs/langchain/langchain/chains/rl_chain/__init__.py @@ -9,8 +9,14 @@ SelectionScorer, ToSelectFrom, VwPolicy, + embed, + stringify_embedding, +) +from langchain.chains.rl_chain.pick_best_chain import ( + PickBest, + PickBestEvent, + PickBestSelected, ) -from langchain.chains.rl_chain.pick_best_chain import PickBest def configure_logger() -> None: @@ -29,6 +35,8 @@ def configure_logger() -> None: __all__ = [ "PickBest", + "PickBestEvent", + "PickBestSelected", "Embed", "BasedOn", "ToSelectFrom", @@ -37,4 +45,6 @@ def configure_logger() -> None: "Embedder", "Policy", "VwPolicy", + "embed", + "stringify_embedding", ] diff --git a/libs/langchain/langchain/chains/rl_chain/base.py b/libs/langchain/langchain/chains/rl_chain/base.py index d08200c7096ce..6e01bb5063a78 100644 --- a/libs/langchain/langchain/chains/rl_chain/base.py +++ b/libs/langchain/langchain/chains/rl_chain/base.py @@ -19,7 +19,10 @@ from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.chains.rl_chain.metrics import MetricsTracker +from langchain.chains.rl_chain.metrics import ( + MetricsTrackerAverage, + MetricsTrackerRollingWindow, +) from langchain.chains.rl_chain.model_repository import ModelRepository from langchain.chains.rl_chain.vw_logger import VwLogger from langchain.prompts import ( @@ -98,6 +101,10 @@ def EmbedAndKeep(anything: Any) -> Any: # helper functions +def stringify_embedding(embedding: List) -> str: + return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)]) + + def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]: return [parser.parse_line(line) for line in input_str.split("\n")] @@ -346,7 +353,7 @@ def log(self, event: TEvent) -> None: selection_scorer_activated: bool = True selected_input_key = "rl_chain_selected" selected_based_on_input_key = "rl_chain_selected_based_on" - metrics: Optional[MetricsTracker] = None + metrics: Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]] = None def __init__( self, @@ -357,6 +364,7 @@ def __init__( policy: Type[Policy] = VwPolicy, vw_logs: Optional[Union[str, os.PathLike]] = None, metrics_step: int = -1, + metrics_window_size: int = -1, *args: Any, **kwargs: Any, ): @@ -378,7 +386,12 @@ def __init__( vw_logger=VwLogger(vw_logs), ) - self.metrics = MetricsTracker(step=metrics_step) + if metrics_window_size > 0: + self.metrics = MetricsTrackerRollingWindow( + step=metrics_step, window_size=metrics_window_size + ) + else: + self.metrics = MetricsTrackerAverage(step=metrics_step) class Config: """Configuration for this pydantic object.""" @@ -523,8 +536,9 @@ def _call( f"The selection scorer was not able to score, \ and the chain was not able to adjust to this response, error: {e}" ) - if self.metrics: + if self.metrics and score is not None: self.metrics.on_feedback(score) + event = self._call_after_scoring_before_learning(score=score, event=event) self.active_policy.learn(event=event) self.active_policy.log(event=event) @@ -547,16 +561,13 @@ def embed_string_type( item: Union[str, _Embed], model: Any, namespace: Optional[str] = None ) -> Dict[str, Union[str, List[str]]]: """Helper function to embed a string or an _Embed object.""" - join_char = "" keep_str = "" if isinstance(item, _Embed): - encoded = model.encode(item.value) - join_char = " " + encoded = stringify_embedding(model.encode(item.value)) if item.keep: keep_str = item.value.replace(" ", "_") + " " elif isinstance(item, str): encoded = item.replace(" ", "_") - join_char = "" else: raise ValueError(f"Unsupported type {type(item)} for embedding") @@ -566,7 +577,7 @@ def embed_string_type( provided when embedding a string or _Embed object." ) - return {namespace: keep_str + join_char.join(map(str, encoded))} + return {namespace: keep_str + encoded} def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]: diff --git a/libs/langchain/langchain/chains/rl_chain/metrics.py b/libs/langchain/langchain/chains/rl_chain/metrics.py index 4d6306f776013..4bd65da3ae537 100644 --- a/libs/langchain/langchain/chains/rl_chain/metrics.py +++ b/libs/langchain/langchain/chains/rl_chain/metrics.py @@ -1,31 +1,66 @@ -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from collections import deque +from typing import TYPE_CHECKING, Dict, List, Union if TYPE_CHECKING: import pandas as pd -class MetricsTracker: +class MetricsTrackerAverage: def __init__(self, step: int): - self._history: List[Dict[str, Union[int, float]]] = [] - self._step: int = step - self._i: int = 0 - self._num: float = 0 - self._denom: float = 0 + self.history: List[Dict[str, Union[int, float]]] = [{"step": 0, "score": 0}] + self.step: int = step + self.i: int = 0 + self.num: float = 0 + self.denom: float = 0 @property def score(self) -> float: - return self._num / self._denom if self._denom > 0 else 0 + return self.num / self.denom if self.denom > 0 else 0 def on_decision(self) -> None: - self._denom += 1 + self.denom += 1 - def on_feedback(self, score: Optional[float]) -> None: - self._num += score or 0 - self._i += 1 - if self._step > 0 and self._i % self._step == 0: - self._history.append({"step": self._i, "score": self.score}) + def on_feedback(self, score: float) -> None: + self.num += score or 0 + self.i += 1 + if self.step > 0 and self.i % self.step == 0: + self.history.append({"step": self.i, "score": self.score}) def to_pandas(self) -> "pd.DataFrame": import pandas as pd - return pd.DataFrame(self._history) + return pd.DataFrame(self.history) + + +class MetricsTrackerRollingWindow: + def __init__(self, window_size: int, step: int): + self.history: List[Dict[str, Union[int, float]]] = [{"step": 0, "score": 0}] + self.step: int = step + self.i: int = 0 + self.window_size: int = window_size + self.queue: deque = deque() + self.sum: float = 0.0 + + @property + def score(self) -> float: + return self.sum / len(self.queue) if len(self.queue) > 0 else 0 + + def on_decision(self) -> None: + pass + + def on_feedback(self, value: float) -> None: + self.sum += value + self.queue.append(value) + self.i += 1 + + if len(self.queue) > self.window_size: + old_val = self.queue.popleft() + self.sum -= old_val + + if self.step > 0 and self.i % self.step == 0: + self.history.append({"step": self.i, "score": self.sum / len(self.queue)}) + + def to_pandas(self) -> "pd.DataFrame": + import pandas as pd + + return pd.DataFrame(self.history) diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py index d4576ce2540bb..7bfa5ad5506d1 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py @@ -8,7 +8,7 @@ from langchain.chat_models import FakeListChatModel from langchain.prompts.prompt import PromptTemplate -encoded_text = "[ e n c o d e d ] " +encoded_keyword = "[encoded]" @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @@ -176,15 +176,13 @@ def test_auto_embeddings_on() -> None: str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = encoded_text + " ".join(char for char in str1) - encoded_str2 = encoded_text + " ".join(char for char in str2) - encoded_str3 = encoded_text + " ".join(char for char in str3) + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) ctx_str_1 = "context1" - ctx_str_2 = "context2" - encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1) - encoded_text + " ".join(char for char in ctx_str_2) + encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) expected = f"""shared |User {ctx_str_1 + " " + encoded_ctx_str_1} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {str3 + " " + encoded_str3} """ # noqa @@ -262,15 +260,15 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None: str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = encoded_text + " ".join(char for char in str1) - encoded_str2 = encoded_text + " ".join(char for char in str2) - encoded_str3 = encoded_text + " ".join(char for char in str3) + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) ctx_str_1 = "context1" ctx_str_2 = "context2" - encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1) - encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2) + encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) + encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) expected = f"""shared |User {encoded_ctx_str_1} |User2 {ctx_str_2 + " " + encoded_ctx_str_2} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {encoded_str3} """ # noqa diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py index c49bacac6085c..8683e3b0e547e 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py @@ -4,7 +4,7 @@ import langchain.chains.rl_chain.base as rl_chain import langchain.chains.rl_chain.pick_best_chain as pick_best_chain -encoded_text = "[ e n c o d e d ] " +encoded_keyword = "[encoded]" @pytest.mark.requires("vowpal_wabbit_next") @@ -80,12 +80,12 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None: str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = encoded_text + " ".join(char for char in str1) - encoded_str2 = encoded_text + " ".join(char for char in str2) - encoded_str3 = encoded_text + " ".join(char for char in str3) + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) ctx_str_1 = "context1" - encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1) + encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) named_actions = {"action1": rl_chain.Embed([str1, str2, str3])} context = {"context": rl_chain.Embed(ctx_str_1)} @@ -104,12 +104,12 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = encoded_text + " ".join(char for char in str1) - encoded_str2 = encoded_text + " ".join(char for char in str2) - encoded_str3 = encoded_text + " ".join(char for char in str3) + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) ctx_str_1 = "context1" - encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1) + encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])} context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)} @@ -170,14 +170,14 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = encoded_text + " ".join(char for char in str1) - encoded_str2 = encoded_text + " ".join(char for char in str2) - encoded_str3 = encoded_text + " ".join(char for char in str3) + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) ctx_str_1 = "context1" ctx_str_2 = "context2" - encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1) - encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2) + encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) + encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])} context = { @@ -203,14 +203,14 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = encoded_text + " ".join(char for char in str1) - encoded_str2 = encoded_text + " ".join(char for char in str2) - encoded_str3 = encoded_text + " ".join(char for char in str3) + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) ctx_str_1 = "context1" ctx_str_2 = "context2" - encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1) - encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2) + encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1)) + encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) named_actions = { "action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3]) @@ -236,14 +236,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = encoded_text + " ".join(char for char in str1) - encoded_text + " ".join(char for char in str2) - encoded_str3 = encoded_text + " ".join(char for char in str3) + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) ctx_str_1 = "context1" ctx_str_2 = "context2" - encoded_text + " ".join(char for char in ctx_str_1) - encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2) + encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) named_actions = { "action1": [ @@ -270,14 +268,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() str1 = "0" str2 = "1" str3 = "2" - encoded_str1 = encoded_text + " ".join(char for char in str1) - encoded_text + " ".join(char for char in str2) - encoded_str3 = encoded_text + " ".join(char for char in str3) + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) + encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3)) ctx_str_1 = "context1" ctx_str_2 = "context2" - encoded_text + " ".join(char for char in ctx_str_1) - encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2) + encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2)) named_actions = { "action1": [ @@ -305,11 +301,11 @@ def test_raw_features_underscored() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "this is a long string" str1_underscored = str1.replace(" ", "_") - encoded_str1 = encoded_text + " ".join(char for char in str1) + encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1)) ctx_str = "this is a long context" ctx_str_underscored = ctx_str.replace(" ", "_") - encoded_ctx_str = encoded_text + " ".join(char for char in ctx_str) + encoded_ctx_str = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str)) # No embeddings named_actions = {"action": [str1]} diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py index bd0cc584ef117..1928eb26c606b 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py @@ -5,7 +5,7 @@ import langchain.chains.rl_chain.base as base -encoded_text = "[ e n c o d e d ] " +encoded_keyword = "[encoded]" @pytest.mark.requires("vowpal_wabbit_next") @@ -17,12 +17,10 @@ def test_simple_context_str_no_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_simple_context_str_w_emb() -> None: str1 = "test" - encoded_str1 = " ".join(char for char in str1) - expected = [{"a_namespace": encoded_text + encoded_str1}] + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + expected = [{"a_namespace": encoded_str1}] assert base.embed(base.Embed(str1), MockEncoder(), "a_namespace") == expected - expected_embed_and_keep = [ - {"a_namespace": str1 + " " + encoded_text + encoded_str1} - ] + expected_embed_and_keep = [{"a_namespace": str1 + " " + encoded_str1}] assert ( base.embed(base.EmbedAndKeep(str1), MockEncoder(), "a_namespace") == expected_embed_and_keep @@ -33,14 +31,14 @@ def test_simple_context_str_w_emb() -> None: def test_simple_context_str_w_nested_emb() -> None: # nested embeddings, innermost wins str1 = "test" - encoded_str1 = " ".join(char for char in str1) - expected = [{"a_namespace": encoded_text + encoded_str1}] + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + expected = [{"a_namespace": encoded_str1}] assert ( base.embed(base.EmbedAndKeep(base.Embed(str1)), MockEncoder(), "a_namespace") == expected ) - expected2 = [{"a_namespace": str1 + " " + encoded_text + encoded_str1}] + expected2 = [{"a_namespace": str1 + " " + encoded_str1}] assert ( base.embed(base.Embed(base.EmbedAndKeep(str1)), MockEncoder(), "a_namespace") == expected2 @@ -56,12 +54,10 @@ def test_context_w_namespace_no_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_context_w_namespace_w_emb() -> None: str1 = "test" - encoded_str1 = " ".join(char for char in str1) - expected = [{"test_namespace": encoded_text + encoded_str1}] + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + expected = [{"test_namespace": encoded_str1}] assert base.embed({"test_namespace": base.Embed(str1)}, MockEncoder()) == expected - expected_embed_and_keep = [ - {"test_namespace": str1 + " " + encoded_text + encoded_str1} - ] + expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}] assert ( base.embed({"test_namespace": base.EmbedAndKeep(str1)}, MockEncoder()) == expected_embed_and_keep @@ -71,12 +67,10 @@ def test_context_w_namespace_w_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_context_w_namespace_w_emb2() -> None: str1 = "test" - encoded_str1 = " ".join(char for char in str1) - expected = [{"test_namespace": encoded_text + encoded_str1}] + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + expected = [{"test_namespace": encoded_str1}] assert base.embed(base.Embed({"test_namespace": str1}), MockEncoder()) == expected - expected_embed_and_keep = [ - {"test_namespace": str1 + " " + encoded_text + encoded_str1} - ] + expected_embed_and_keep = [{"test_namespace": str1 + " " + encoded_str1}] assert ( base.embed(base.EmbedAndKeep({"test_namespace": str1}), MockEncoder()) == expected_embed_and_keep @@ -87,10 +81,8 @@ def test_context_w_namespace_w_emb2() -> None: def test_context_w_namespace_w_some_emb() -> None: str1 = "test1" str2 = "test2" - encoded_str2 = " ".join(char for char in str2) - expected = [ - {"test_namespace": str1, "test_namespace2": encoded_text + encoded_str2} - ] + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + expected = [{"test_namespace": str1, "test_namespace2": encoded_str2}] assert ( base.embed( {"test_namespace": str1, "test_namespace2": base.Embed(str2)}, MockEncoder() @@ -100,7 +92,7 @@ def test_context_w_namespace_w_some_emb() -> None: expected_embed_and_keep = [ { "test_namespace": str1, - "test_namespace2": str2 + " " + encoded_text + encoded_str2, + "test_namespace2": str2 + " " + encoded_str2, } ] assert ( @@ -127,22 +119,22 @@ def test_simple_action_strlist_w_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" - encoded_str1 = " ".join(char for char in str1) - encoded_str2 = " ".join(char for char in str2) - encoded_str3 = " ".join(char for char in str3) + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) expected = [ - {"a_namespace": encoded_text + encoded_str1}, - {"a_namespace": encoded_text + encoded_str2}, - {"a_namespace": encoded_text + encoded_str3}, + {"a_namespace": encoded_str1}, + {"a_namespace": encoded_str2}, + {"a_namespace": encoded_str3}, ] assert ( base.embed(base.Embed([str1, str2, str3]), MockEncoder(), "a_namespace") == expected ) expected_embed_and_keep = [ - {"a_namespace": str1 + " " + encoded_text + encoded_str1}, - {"a_namespace": str2 + " " + encoded_text + encoded_str2}, - {"a_namespace": str3 + " " + encoded_text + encoded_str3}, + {"a_namespace": str1 + " " + encoded_str1}, + {"a_namespace": str2 + " " + encoded_str2}, + {"a_namespace": str3 + " " + encoded_str3}, ] assert ( base.embed(base.EmbedAndKeep([str1, str2, str3]), MockEncoder(), "a_namespace") @@ -155,12 +147,12 @@ def test_simple_action_strlist_w_some_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" - encoded_str2 = " ".join(char for char in str2) - encoded_str3 = " ".join(char for char in str3) + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) expected = [ {"a_namespace": str1}, - {"a_namespace": encoded_text + encoded_str2}, - {"a_namespace": encoded_text + encoded_str3}, + {"a_namespace": encoded_str2}, + {"a_namespace": encoded_str3}, ] assert ( base.embed( @@ -170,8 +162,8 @@ def test_simple_action_strlist_w_some_emb() -> None: ) expected_embed_and_keep = [ {"a_namespace": str1}, - {"a_namespace": str2 + " " + encoded_text + encoded_str2}, - {"a_namespace": str3 + " " + encoded_text + encoded_str3}, + {"a_namespace": str2 + " " + encoded_str2}, + {"a_namespace": str3 + " " + encoded_str3}, ] assert ( base.embed( @@ -211,13 +203,13 @@ def test_action_w_namespace_w_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" - encoded_str1 = " ".join(char for char in str1) - encoded_str2 = " ".join(char for char in str2) - encoded_str3 = " ".join(char for char in str3) + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) expected = [ - {"test_namespace": encoded_text + encoded_str1}, - {"test_namespace": encoded_text + encoded_str2}, - {"test_namespace": encoded_text + encoded_str3}, + {"test_namespace": encoded_str1}, + {"test_namespace": encoded_str2}, + {"test_namespace": encoded_str3}, ] assert ( base.embed( @@ -231,9 +223,9 @@ def test_action_w_namespace_w_emb() -> None: == expected ) expected_embed_and_keep = [ - {"test_namespace": str1 + " " + encoded_text + encoded_str1}, - {"test_namespace": str2 + " " + encoded_text + encoded_str2}, - {"test_namespace": str3 + " " + encoded_text + encoded_str3}, + {"test_namespace": str1 + " " + encoded_str1}, + {"test_namespace": str2 + " " + encoded_str2}, + {"test_namespace": str3 + " " + encoded_str3}, ] assert ( base.embed( @@ -253,13 +245,13 @@ def test_action_w_namespace_w_emb2() -> None: str1 = "test1" str2 = "test2" str3 = "test3" - encoded_str1 = " ".join(char for char in str1) - encoded_str2 = " ".join(char for char in str2) - encoded_str3 = " ".join(char for char in str3) + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) expected = [ - {"test_namespace1": encoded_text + encoded_str1}, - {"test_namespace2": encoded_text + encoded_str2}, - {"test_namespace3": encoded_text + encoded_str3}, + {"test_namespace1": encoded_str1}, + {"test_namespace2": encoded_str2}, + {"test_namespace3": encoded_str3}, ] assert ( base.embed( @@ -275,9 +267,9 @@ def test_action_w_namespace_w_emb2() -> None: == expected ) expected_embed_and_keep = [ - {"test_namespace1": str1 + " " + encoded_text + encoded_str1}, - {"test_namespace2": str2 + " " + encoded_text + encoded_str2}, - {"test_namespace3": str3 + " " + encoded_text + encoded_str3}, + {"test_namespace1": str1 + " " + encoded_str1}, + {"test_namespace2": str2 + " " + encoded_str2}, + {"test_namespace3": str3 + " " + encoded_str3}, ] assert ( base.embed( @@ -299,12 +291,12 @@ def test_action_w_namespace_w_some_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" - encoded_str2 = " ".join(char for char in str2) - encoded_str3 = " ".join(char for char in str3) + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) expected = [ {"test_namespace": str1}, - {"test_namespace": encoded_text + encoded_str2}, - {"test_namespace": encoded_text + encoded_str3}, + {"test_namespace": encoded_str2}, + {"test_namespace": encoded_str3}, ] assert ( base.embed( @@ -319,8 +311,8 @@ def test_action_w_namespace_w_some_emb() -> None: ) expected_embed_and_keep = [ {"test_namespace": str1}, - {"test_namespace": str2 + " " + encoded_text + encoded_str2}, - {"test_namespace": str3 + " " + encoded_text + encoded_str3}, + {"test_namespace": str2 + " " + encoded_str2}, + {"test_namespace": str3 + " " + encoded_str3}, ] assert ( base.embed( @@ -340,13 +332,13 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None: str1 = "test1" str2 = "test2" str3 = "test3" - encoded_str1 = " ".join(char for char in str1) - encoded_str2 = " ".join(char for char in str2) - encoded_str3 = " ".join(char for char in str3) + encoded_str1 = base.stringify_embedding(list(encoded_keyword + str1)) + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + encoded_str3 = base.stringify_embedding(list(encoded_keyword + str3)) expected = [ - {"test_namespace": encoded_text + encoded_str1, "test_namespace2": str1}, - {"test_namespace": encoded_text + encoded_str2, "test_namespace2": str2}, - {"test_namespace": encoded_text + encoded_str3, "test_namespace2": str3}, + {"test_namespace": encoded_str1, "test_namespace2": str1}, + {"test_namespace": encoded_str2, "test_namespace2": str2}, + {"test_namespace": encoded_str3, "test_namespace2": str3}, ] assert ( base.embed( @@ -361,15 +353,15 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None: ) expected_embed_and_keep = [ { - "test_namespace": str1 + " " + encoded_text + encoded_str1, + "test_namespace": str1 + " " + encoded_str1, "test_namespace2": str1, }, { - "test_namespace": str2 + " " + encoded_text + encoded_str2, + "test_namespace": str2 + " " + encoded_str2, "test_namespace2": str2, }, { - "test_namespace": str3 + " " + encoded_text + encoded_str3, + "test_namespace": str3 + " " + encoded_str3, "test_namespace2": str3, }, ] @@ -398,8 +390,8 @@ def test_one_namespace_w_list_of_features_no_emb() -> None: def test_one_namespace_w_list_of_features_w_some_emb() -> None: str1 = "test1" str2 = "test2" - encoded_str2 = " ".join(char for char in str2) - expected = [{"test_namespace": [str1, encoded_text + encoded_str2]}] + encoded_str2 = base.stringify_embedding(list(encoded_keyword + str2)) + expected = [{"test_namespace": [str1, encoded_str2]}] assert ( base.embed({"test_namespace": [str1, base.Embed(str2)]}, MockEncoder()) == expected