Merge pull request #9 from VowpalWabbit/fix_embedding_w_indexes
proper embeddings and rolling window average
olgavrou authored Sep 1, 2023
2 parents 2b90a8a + 2c877a4 commit a9ba6a8
Showing 6 changed files with 184 additions and 142 deletions.
12 changes: 11 additions & 1 deletion libs/langchain/langchain/chains/rl_chain/__init__.py
@@ -9,8 +9,14 @@
     SelectionScorer,
     ToSelectFrom,
     VwPolicy,
+    embed,
+    stringify_embedding,
 )
-from langchain.chains.rl_chain.pick_best_chain import PickBest
+from langchain.chains.rl_chain.pick_best_chain import (
+    PickBest,
+    PickBestEvent,
+    PickBestSelected,
+)


 def configure_logger() -> None:
@@ -29,6 +35,8 @@ def configure_logger() -> None:

 __all__ = [
     "PickBest",
+    "PickBestEvent",
+    "PickBestSelected",
     "Embed",
     "BasedOn",
     "ToSelectFrom",
@@ -37,4 +45,6 @@ def configure_logger() -> None:
     "Embedder",
     "Policy",
     "VwPolicy",
+    "embed",
+    "stringify_embedding",
 ]
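With these exports in place, downstream code can import the new helpers straight from the package; a minimal usage sketch (all names are confirmed by the `__all__` list above):

```python
from langchain.chains.rl_chain import (
    PickBest,
    PickBestEvent,
    PickBestSelected,
    embed,
    stringify_embedding,
)
```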
29 changes: 20 additions & 9 deletions libs/langchain/langchain/chains/rl_chain/base.py
@@ -19,7 +19,10 @@
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
-from langchain.chains.rl_chain.metrics import MetricsTracker
+from langchain.chains.rl_chain.metrics import (
+    MetricsTrackerAverage,
+    MetricsTrackerRollingWindow,
+)
 from langchain.chains.rl_chain.model_repository import ModelRepository
 from langchain.chains.rl_chain.vw_logger import VwLogger
 from langchain.prompts import (
@@ -98,6 +101,10 @@ def EmbedAndKeep(anything: Any) -> Any:
 # helper functions


+def stringify_embedding(embedding: List) -> str:
+    return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
+
+
 def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
     return [parser.parse_line(line) for line in input_str.split("\n")]

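The new `stringify_embedding` helper renders a dense vector in VW's `index:value` feature syntax; for example:

```python
from langchain.chains.rl_chain import stringify_embedding

# each dimension becomes "index:value", joined by spaces
print(stringify_embedding([0.25, -0.5, 1.0]))  # -> "0:0.25 1:-0.5 2:1.0"
```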
@@ -346,7 +353,7 @@ def log(self, event: TEvent) -> None:
     selection_scorer_activated: bool = True
     selected_input_key = "rl_chain_selected"
     selected_based_on_input_key = "rl_chain_selected_based_on"
-    metrics: Optional[MetricsTracker] = None
+    metrics: Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]] = None

     def __init__(
         self,
@@ -357,6 +364,7 @@ def __init__(
         policy: Type[Policy] = VwPolicy,
         vw_logs: Optional[Union[str, os.PathLike]] = None,
         metrics_step: int = -1,
+        metrics_window_size: int = -1,
         *args: Any,
         **kwargs: Any,
     ):
@@ -378,7 +386,12 @@ def __init__(
             vw_logger=VwLogger(vw_logs),
         )

-        self.metrics = MetricsTracker(step=metrics_step)
+        if metrics_window_size > 0:
+            self.metrics = MetricsTrackerRollingWindow(
+                step=metrics_step, window_size=metrics_window_size
+            )
+        else:
+            self.metrics = MetricsTrackerAverage(step=metrics_step)

     class Config:
         """Configuration for this pydantic object."""
@@ -523,8 +536,9 @@ def _call(
                 f"The selection scorer was not able to score, \
                 and the chain was not able to adjust to this response, error: {e}"
             )
-        if self.metrics:
+        if self.metrics and score is not None:
             self.metrics.on_feedback(score)
+
         event = self._call_after_scoring_before_learning(score=score, event=event)
         self.active_policy.learn(event=event)
         self.active_policy.log(event=event)
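The added `score is not None` guard matters for the new rolling-window tracker: unlike the running average (which coerces with `score or 0`), its `on_feedback` adds the value directly, so a `None` from a failed scorer would raise. A sketch of the guarded call (hypothetical values):

```python
from langchain.chains.rl_chain.metrics import MetricsTrackerRollingWindow

tracker = MetricsTrackerRollingWindow(window_size=2, step=1)
score = None  # e.g. the selection scorer raised and no score was produced

# mirrors the guard in _call above; without it, tracker.on_feedback(None)
# would raise a TypeError inside "self.sum += value"
if tracker and score is not None:
    tracker.on_feedback(score)
```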
@@ -547,16 +561,13 @@ def embed_string_type
     item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
 ) -> Dict[str, Union[str, List[str]]]:
     """Helper function to embed a string or an _Embed object."""
-    join_char = ""
     keep_str = ""
     if isinstance(item, _Embed):
-        encoded = model.encode(item.value)
-        join_char = " "
+        encoded = stringify_embedding(model.encode(item.value))
         if item.keep:
             keep_str = item.value.replace(" ", "_") + " "
     elif isinstance(item, str):
         encoded = item.replace(" ", "_")
-        join_char = ""
     else:
         raise ValueError(f"Unsupported type {type(item)} for embedding")

@@ -566,7 +577,7 @@ def embed_string_type(
             provided when embedding a string or _Embed object."
         )

-    return {namespace: keep_str + join_char.join(map(str, encoded))}
+    return {namespace: keep_str + encoded}


 def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
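To see the effect of the change, here is a hedged sketch of `embed_string_type` with a toy encoder (`ToyModel` is a stand-in for a sentence-transformers style model and is not part of the patch; the `EmbedAndKeep` behavior shown is inferred from the tests below):

```python
from langchain.chains.rl_chain.base import EmbedAndKeep, embed_string_type

class ToyModel:
    def encode(self, text: str):
        # a real model would return a dense embedding vector
        return [0.1, 0.2]

# EmbedAndKeep keeps the underscored raw feature and appends the
# stringified embedding; plain Embed would yield only the embedding part
print(embed_string_type(EmbedAndKeep("hello world"), ToyModel(), "ns"))
# -> {'ns': 'hello_world 0:0.1 1:0.2'}
```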
65 changes: 50 additions & 15 deletions libs/langchain/langchain/chains/rl_chain/metrics.py
@@ -1,31 +1,66 @@
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from collections import deque
+from typing import TYPE_CHECKING, Dict, List, Union

 if TYPE_CHECKING:
     import pandas as pd


-class MetricsTracker:
+class MetricsTrackerAverage:
     def __init__(self, step: int):
-        self._history: List[Dict[str, Union[int, float]]] = []
-        self._step: int = step
-        self._i: int = 0
-        self._num: float = 0
-        self._denom: float = 0
+        self.history: List[Dict[str, Union[int, float]]] = [{"step": 0, "score": 0}]
+        self.step: int = step
+        self.i: int = 0
+        self.num: float = 0
+        self.denom: float = 0

     @property
     def score(self) -> float:
-        return self._num / self._denom if self._denom > 0 else 0
+        return self.num / self.denom if self.denom > 0 else 0

     def on_decision(self) -> None:
-        self._denom += 1
+        self.denom += 1

-    def on_feedback(self, score: Optional[float]) -> None:
-        self._num += score or 0
-        self._i += 1
-        if self._step > 0 and self._i % self._step == 0:
-            self._history.append({"step": self._i, "score": self.score})
+    def on_feedback(self, score: float) -> None:
+        self.num += score or 0
+        self.i += 1
+        if self.step > 0 and self.i % self.step == 0:
+            self.history.append({"step": self.i, "score": self.score})

     def to_pandas(self) -> "pd.DataFrame":
         import pandas as pd

-        return pd.DataFrame(self._history)
+        return pd.DataFrame(self.history)
+
+
+class MetricsTrackerRollingWindow:
+    def __init__(self, window_size: int, step: int):
+        self.history: List[Dict[str, Union[int, float]]] = [{"step": 0, "score": 0}]
+        self.step: int = step
+        self.i: int = 0
+        self.window_size: int = window_size
+        self.queue: deque = deque()
+        self.sum: float = 0.0
+
+    @property
+    def score(self) -> float:
+        return self.sum / len(self.queue) if len(self.queue) > 0 else 0
+
+    def on_decision(self) -> None:
+        pass
+
+    def on_feedback(self, value: float) -> None:
+        self.sum += value
+        self.queue.append(value)
+        self.i += 1
+
+        if len(self.queue) > self.window_size:
+            old_val = self.queue.popleft()
+            self.sum -= old_val
+
+        if self.step > 0 and self.i % self.step == 0:
+            self.history.append({"step": self.i, "score": self.sum / len(self.queue)})
+
+    def to_pandas(self) -> "pd.DataFrame":
+        import pandas as pd
+
+        return pd.DataFrame(self.history)
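A short usage sketch of the new rolling-window tracker (values are illustrative; `to_pandas` requires pandas to be installed):

```python
from langchain.chains.rl_chain.metrics import MetricsTrackerRollingWindow

tracker = MetricsTrackerRollingWindow(window_size=2, step=1)
for s in [1.0, 0.0, 1.0]:
    tracker.on_feedback(s)

# only the last two scores remain in the window: (0.0 + 1.0) / 2
print(tracker.score)  # 0.5
print(tracker.to_pandas())  # step/score snapshots, taken every `step` feedbacks
```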
@@ -8,7 +8,7 @@
 from langchain.chat_models import FakeListChatModel
 from langchain.prompts.prompt import PromptTemplate

-encoded_text = "[ e n c o d e d ] "
+encoded_keyword = "[encoded]"


 @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
@@ -176,15 +176,13 @@ def test_auto_embeddings_on() -> None:
     str1 = "0"
     str2 = "1"
     str3 = "2"
-    encoded_str1 = encoded_text + " ".join(char for char in str1)
-    encoded_str2 = encoded_text + " ".join(char for char in str2)
-    encoded_str3 = encoded_text + " ".join(char for char in str3)
+    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
+    encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
+    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

     ctx_str_1 = "context1"
     ctx_str_2 = "context2"

-    encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
-    encoded_text + " ".join(char for char in ctx_str_2)
+    encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))

     expected = f"""shared |User {ctx_str_1 + " " + encoded_ctx_str_1} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {str3 + " " + encoded_str3} """  # noqa
@@ -262,15 +260,15 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
     str1 = "0"
     str2 = "1"
     str3 = "2"
-    encoded_str1 = encoded_text + " ".join(char for char in str1)
-    encoded_str2 = encoded_text + " ".join(char for char in str2)
-    encoded_str3 = encoded_text + " ".join(char for char in str3)
+    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
+    encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
+    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

     ctx_str_1 = "context1"
     ctx_str_2 = "context2"

-    encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
-    encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
+    encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
+    encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))

     expected = f"""shared |User {encoded_ctx_str_1} |User2 {ctx_str_2 + " " + encoded_ctx_str_2} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {encoded_str3} """  # noqa
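The rewritten expectations rely on the suite's `MockEncoder`, which (judging by these tests) returns the character list of `"[encoded]" + input`, so each expected feature string can be derived by hand:

```python
encoded_keyword = "[encoded]"

def stringify_embedding(embedding: list) -> str:
    # same helper the tests import from rl_chain
    return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])

print(stringify_embedding(list(encoded_keyword + "0")))
# -> "0:[ 1:e 2:n 3:c 4:o 5:d 6:e 7:d 8:] 9:0"
```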
@@ -4,7 +4,7 @@
 import langchain.chains.rl_chain.base as rl_chain
 import langchain.chains.rl_chain.pick_best_chain as pick_best_chain

-encoded_text = "[ e n c o d e d ] "
+encoded_keyword = "[encoded]"


 @pytest.mark.requires("vowpal_wabbit_next")
@@ -80,12 +80,12 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None:
     str1 = "0"
     str2 = "1"
     str3 = "2"
-    encoded_str1 = encoded_text + " ".join(char for char in str1)
-    encoded_str2 = encoded_text + " ".join(char for char in str2)
-    encoded_str3 = encoded_text + " ".join(char for char in str3)
+    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
+    encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
+    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

     ctx_str_1 = "context1"
-    encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
+    encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))

     named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
     context = {"context": rl_chain.Embed(ctx_str_1)}
@@ -104,12 +104,12 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
     str1 = "0"
     str2 = "1"
     str3 = "2"
-    encoded_str1 = encoded_text + " ".join(char for char in str1)
-    encoded_str2 = encoded_text + " ".join(char for char in str2)
-    encoded_str3 = encoded_text + " ".join(char for char in str3)
+    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
+    encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
+    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

     ctx_str_1 = "context1"
-    encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
+    encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))

     named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
     context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
@@ -170,14 +170,14 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None
     str1 = "0"
     str2 = "1"
     str3 = "2"
-    encoded_str1 = encoded_text + " ".join(char for char in str1)
-    encoded_str2 = encoded_text + " ".join(char for char in str2)
-    encoded_str3 = encoded_text + " ".join(char for char in str3)
+    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
+    encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
+    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

     ctx_str_1 = "context1"
     ctx_str_2 = "context2"
-    encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
-    encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
+    encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
+    encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))

     named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])}
     context = {
@@ -203,14 +203,14 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
     str1 = "0"
     str2 = "1"
     str3 = "2"
-    encoded_str1 = encoded_text + " ".join(char for char in str1)
-    encoded_str2 = encoded_text + " ".join(char for char in str2)
-    encoded_str3 = encoded_text + " ".join(char for char in str3)
+    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
+    encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
+    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

     ctx_str_1 = "context1"
     ctx_str_2 = "context2"
-    encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
-    encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
+    encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
+    encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))

     named_actions = {
         "action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3])
@@ -236,14 +236,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N
     str1 = "0"
     str2 = "1"
     str3 = "2"
-    encoded_str1 = encoded_text + " ".join(char for char in str1)
-    encoded_text + " ".join(char for char in str2)
-    encoded_str3 = encoded_text + " ".join(char for char in str3)
+    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
+    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

     ctx_str_1 = "context1"
     ctx_str_2 = "context2"
-    encoded_text + " ".join(char for char in ctx_str_1)
-    encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
+    encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))

     named_actions = {
         "action1": [
@@ -270,14 +268,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep()
     str1 = "0"
     str2 = "1"
     str3 = "2"
-    encoded_str1 = encoded_text + " ".join(char for char in str1)
-    encoded_text + " ".join(char for char in str2)
-    encoded_str3 = encoded_text + " ".join(char for char in str3)
+    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
+    encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))

     ctx_str_1 = "context1"
     ctx_str_2 = "context2"
-    encoded_text + " ".join(char for char in ctx_str_1)
-    encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
+    encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))

     named_actions = {
         "action1": [
@@ -305,11 +301,11 @@ def test_raw_features_underscored() -> None:
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
     str1 = "this is a long string"
     str1_underscored = str1.replace(" ", "_")
-    encoded_str1 = encoded_text + " ".join(char for char in str1)
+    encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))

     ctx_str = "this is a long context"
     ctx_str_underscored = ctx_str.replace(" ", "_")
-    encoded_ctx_str = encoded_text + " ".join(char for char in ctx_str)
+    encoded_ctx_str = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str))

     # No embeddings
     named_actions = {"action": [str1]}