
Commit 571ee71

Merge pull request #2 from VowpalWabbit/fixes
Dependency and import fixes
2 parents e942330 + c9e9c0e commit 571ee71

File tree: 11 files changed (+295, -618 lines)

libs/langchain/langchain/chains/rl_chain/__init__.py

Lines changed: 19 additions & 7 deletions
@@ -1,16 +1,16 @@
-from langchain.chains.rl_chain.pick_best_chain import PickBest
+import logging
+
 from langchain.chains.rl_chain.base import (
-    Embed,
-    BasedOn,
-    ToSelectFrom,
-    SelectionScorer,
     AutoSelectionScorer,
+    BasedOn,
+    Embed,
     Embedder,
     Policy,
+    SelectionScorer,
+    ToSelectFrom,
     VwPolicy,
 )
-
-import logging
+from langchain.chains.rl_chain.pick_best_chain import PickBest
 
 
 def configure_logger():
@@ -26,3 +26,15 @@ def configure_logger():
 
 
 configure_logger()
+
+__all__ = [
+    "PickBest",
+    "Embed",
+    "BasedOn",
+    "ToSelectFrom",
+    "SelectionScorer",
+    "AutoSelectionScorer",
+    "Embedder",
+    "Policy",
+    "VwPolicy",
+]
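With `__all__` in place, the package root becomes the supported import surface. A minimal usage sketch of the exported wrappers (how the chain consumes these inputs lives in base.py, not in this file, so treat the dict below as illustrative):

from langchain.chains.rl_chain import BasedOn, ToSelectFrom

# BasedOn marks the context the policy conditions on; ToSelectFrom marks
# the candidate list the learned policy picks from.
inputs = {
    "user": BasedOn("Tom"),
    "action": ToSelectFrom(["action one", "action two"]),
}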

libs/langchain/langchain/chains/rl_chain/base.py

Lines changed: 69 additions & 52 deletions
@@ -2,25 +2,25 @@
 
 import logging
 import os
-from typing import Any, Dict, List, Optional, Tuple, Union, Sequence
 from abc import ABC, abstractmethod
-
-import vowpal_wabbit_next as vw
-from langchain.chains.rl_chain.vw_logger import VwLogger
-from langchain.chains.rl_chain.model_repository import ModelRepository
-from langchain.chains.rl_chain.metrics import MetricsTracker
-from langchain.prompts import BasePromptTemplate
-
-from langchain.pydantic_v1 import Extra, BaseModel, root_validator
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
 
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
+from langchain.chains.rl_chain.metrics import MetricsTracker
+from langchain.chains.rl_chain.model_repository import ModelRepository
+from langchain.chains.rl_chain.vw_logger import VwLogger
 from langchain.prompts import (
+    BasePromptTemplate,
     ChatPromptTemplate,
-    SystemMessagePromptTemplate,
     HumanMessagePromptTemplate,
+    SystemMessagePromptTemplate,
 )
+from langchain.pydantic_v1 import BaseModel, Extra, root_validator
+
+if TYPE_CHECKING:
+    import vowpal_wabbit_next as vw
 
 logger = logging.getLogger(__name__)
 
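This hunk is the import-time half of the dependency fix: `vowpal_wabbit_next` moves under a `TYPE_CHECKING` guard, so importing the module no longer requires the package to be installed. A generic sketch of the pattern (the `describe` helper is hypothetical, for illustration only):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # visible to static type checkers only; never executed at runtime
    import vowpal_wabbit_next as vw


def describe(workspace: "vw.Workspace") -> str:
    # the quoted annotation stays valid without a runtime import
    return repr(workspace)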
@@ -87,7 +87,7 @@ def EmbedAndKeep(anything):
 # helper functions
 
 
-def parse_lines(parser: vw.TextFormatParser, input_str: str) -> List[vw.Example]:
+def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
     return [parser.parse_line(line) for line in input_str.split("\n")]
 
 
@@ -100,7 +100,8 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
 
     if not to_select_from:
         raise ValueError(
-            "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
+            "No variables using 'ToSelectFrom' found in the inputs. \
+            Please include at least one variable containing a list to select from."
         )
 
     based_on = {
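A note on the wrapped strings used throughout this PR: a trailing backslash inside a string literal removes the newline but keeps whatever indentation follows, so the runtime message contains the continuation line's leading spaces. A standalone illustration (not from the diff):

msg = "No variables using 'ToSelectFrom' found in the inputs. \
Please include at least one variable containing a list to select from."

assert "\n" not in msg  # the backslash-newline was removed entirely
# if the second line were indented, those spaces would be part of msg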
@@ -113,8 +114,11 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
 
 
 def prepare_inputs_for_autoembed(inputs: Dict[str, Any]):
-    # go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if
-    # their inner values are not already _Embed, then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
+    """
+    go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
+    then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
+    """  # noqa: E501
+
     next_inputs = inputs.copy()
     for k, v in next_inputs.items():
         if isinstance(v, _ToSelectFrom) or isinstance(v, _BasedOn):
@@ -173,14 +177,17 @@ def __init__(
         self.vw_logger = vw_logger
 
     def predict(self, event: Event) -> Any:
+        import vowpal_wabbit_next as vw
+
         text_parser = vw.TextFormatParser(self.workspace)
         return self.workspace.predict_one(
            parse_lines(text_parser, self.feature_embedder.format(event))
        )
 
     def learn(self, event: Event):
-        vw_ex = self.feature_embedder.format(event)
+        import vowpal_wabbit_next as vw
 
+        vw_ex = self.feature_embedder.format(event)
         text_parser = vw.TextFormatParser(self.workspace)
         multi_ex = parse_lines(text_parser, vw_ex)
         self.workspace.learn_one(multi_ex)
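`predict` and `learn` show the runtime half of the fix: the import moves into the method body, so a missing dependency only fails when the policy is actually used. A self-contained sketch of the same technique (the class name, VW arguments, and error message are illustrative, not from this PR):

class LazyPolicy:
    """Sketch: defer the heavy dependency until the policy is used."""

    def learn(self, vw_ex: str) -> None:
        try:
            import vowpal_wabbit_next as vw  # fails here, not at module import
        except ImportError as e:
            raise ImportError(
                "vowpal_wabbit_next is required to train the policy"
            ) from e

        workspace = vw.Workspace(["--cb_explore_adf"])
        parser = vw.TextFormatParser(workspace)
        workspace.learn_one([parser.parse_line(line) for line in vw_ex.split("\n")])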
@@ -216,13 +223,18 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
     @staticmethod
     def get_default_system_prompt() -> SystemMessagePromptTemplate:
         return SystemMessagePromptTemplate.from_template(
-            "PLEASE RESPOND ONLY WITH A SIGNLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
-            You must respond with your ranking by providing a single float within the range [0, 1], 0 being very bad response and 1 being very good response."
+            "PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n \
+            You are a strict judge that is called on to rank a response based on \
+            given criteria. You must respond with your ranking by providing a \
+            single float within the range [0, 1], 0 being very bad \
+            response and 1 being very good response."
         )
 
     @staticmethod
     def get_default_prompt() -> ChatPromptTemplate:
-        human_template = 'Given this based_on "{rl_chain_selected_based_on}" as the most important attribute, rank how good or bad this text is: "{llm_response}".'
+        human_template = 'Given this based_on "{rl_chain_selected_based_on}" \
+        as the most important attribute, rank how good or bad this text is: \
+        "{llm_response}".'
         human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
         default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
         chat_prompt = ChatPromptTemplate.from_messages(
@@ -257,25 +269,36 @@ def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
             return resp
         except Exception as e:
             raise RuntimeError(
-                f"The llm did not manage to rank the response as expected, there is always the option to try again or tweak the reward prompt. Error: {e}"
+                f"The auto selection scorer did not manage to score the response, \
+                there is always the option to try again or tweak the reward prompt.\
+                Error: {e}"
             )
 
 
 class RLChain(Chain):
     """
-    RLChain class that utilizes the Vowpal Wabbit (VW) model for personalization.
+    The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
 
     Attributes:
-        model_loading (bool, optional): If set to True, the chain will attempt to load an existing VW model from the latest checkpoint file in the {model_save_dir} directory (current directory if none specified). If set to False, it will start training from scratch, potentially overwriting existing files. Defaults to True.
-        large_action_spaces (bool, optional): If set to True and vw_cmd has not been specified in the constructor, it will enable large action spaces
-        vw_cmd (List[str], optional): Advanced users can set the VW command line to whatever they want, as long as it is compatible with the Type that is specified (Type Enum)
-        model_save_dir (str, optional): The directory to save the VW model to. Defaults to the current directory.
-        selection_scorer (SelectionScorer): If set, the chain will check the response using the provided selection_scorer and the VW model will be updated with the result. Defaults to None.
+        - llm_chain (Chain): Represents the underlying Language Model chain.
+        - prompt (BasePromptTemplate): The template for the base prompt.
+        - selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
+        - policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
+        - auto_embed (bool): Determines if embedding should be automatic. Default is True.
+        - metrics (Optional[MetricsTracker]): Tracker for metrics, can be set to None.
+
+    Initialization Attributes:
+        - feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs.
+        - model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
+        - reset_model (bool): If set to True, the model starts training from scratch. Default is False.
+        - vw_cmd (List[str], optional): Command line arguments for the VW model.
+        - policy (VwPolicy): Policy used by the chain.
+        - vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
+        - metrics_step (int): Step for the metrics tracker. Default is -1.
 
     Notes:
-        The class creates a VW model instance using the provided arguments. Before the chain object is destroyed the save_progress() function can be called. If it is called, the learned VW model is saved to a file in the current directory named `model-<checkpoint>.vw`. Checkpoints start at 1 and increment monotonically.
-        When making predictions, VW is first called to choose action(s) which are then passed into the prompt with the key `{actions}`. After action selection, the LLM (Language Model) is called with the prompt populated by the chosen action(s), and the response is returned.
-    """
+        The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
+    """  # noqa: E501
 
     llm_chain: Chain
 
@@ -303,7 +326,9 @@ def __init__(
         super().__init__(*args, **kwargs)
         if self.selection_scorer is None:
             logger.warning(
-                "No response validator provided, which means that no reinforcement learning will be done in the RL chain unless update_with_delayed_score is called."
+                "No selection scorer provided, which means that no \
+                reinforcement learning will be done in the RL chain \
+                unless update_with_delayed_score is called."
             )
         self.policy = policy(
             model_repo=ModelRepository(
@@ -343,7 +368,9 @@ def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
             or self.selected_based_on_input_key in inputs.keys()
         ):
             raise ValueError(
-                f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward."
+                f"The rl chain does not accept '{self.selected_input_key}' \
+                or '{self.selected_based_on_input_key}' as input keys, \
+                they are reserved for internal use during auto reward."
             )
 
     @abstractmethod
@@ -372,13 +399,13 @@ def update_with_delayed_score(
         self, score: float, event: Event, force_score=False
     ) -> None:
         """
-        Learn will be called with the score specified and the actions/embeddings/etc stored in event
-
+        Updates the learned policy with the score provided.
         Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
-        """
+        """  # noqa: E501
         if self.selection_scorer and not force_score:
             raise RuntimeError(
-                "The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function."
+                "The selection scorer is set, and force_score was not set to True. \
+                Please set force_score=True to use this function."
             )
         self.metrics.on_feedback(score)
         self._call_after_scoring_before_learning(event=event, score=score)
@@ -387,10 +414,7 @@ def update_with_delayed_score(
 
     def set_auto_embed(self, auto_embed: bool) -> None:
         """
-        Set whether the chain should auto embed the inputs or not. If set to False, the inputs will not be embedded and the user will need to embed the inputs themselves before calling run.
-
-        Args:
-            auto_embed (bool): Whether the chain should auto embed the inputs or not.
+        Sets whether the chain should auto embed the inputs or not.
         """
         self.auto_embed = auto_embed
 
@@ -435,7 +459,8 @@ def _call(
             )
         except Exception as e:
             logger.info(
-                f"The LLM was not able to rank and the chain was not able to adjust to this response, error: {e}"
+                f"The selection scorer was not able to score, \
+                and the chain was not able to adjust to this response, error: {e}"
             )
         self.metrics.on_feedback(score)
         event = self._call_after_scoring_before_learning(score=score, event=event)
@@ -446,16 +471,7 @@ def _call(
 
     def save_progress(self) -> None:
         """
-        This function should be called whenever there is a need to save the progress of the VW (Vowpal Wabbit) model within the chain. It saves the current state of the VW model to a file.
-
-        File Naming Convention:
-            The file will be named using the pattern `model-<checkpoint>.vw`, where `<checkpoint>` is a monotonically increasing number. The numbering starts from 1, and increments by 1 for each subsequent save. If there are already saved checkpoints, the number used for `<checkpoint>` will be the next in the sequence.
-
-        Example:
-            If there are already two saved checkpoints, `model-1.vw` and `model-2.vw`, the next time this function is called, it will save the model as `model-3.vw`.
-
-        Note:
-            Be cautious when deleting or renaming checkpoint files manually, as this could cause the function to reuse checkpoint numbers.
+        This function should be called to save the state of the learned policy model.
         """
         self.policy.save()
 
@@ -490,7 +506,8 @@ def embed_string_type(
 
     if namespace is None:
         raise ValueError(
-            "The default namespace must be provided when embedding a string or _Embed object."
+            "The default namespace must be \
+            provided when embedding a string or _Embed object."
        )
 
     return {namespace: keep_str + join_char.join(map(str, encoded))}
@@ -530,15 +547,15 @@ def embed(
     namespace: Optional[str] = None,
 ) -> List[Dict[str, Union[str, List[str]]]]:
     """
-    Embeds the actions or context using the SentenceTransformer model
+    Embeds the actions or context using the SentenceTransformer model (or a model that has an `encode` function)
 
     Attributes:
        to_embed: (Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]], required) The text to be embedded, either a string, a list of strings or a dictionary or a list of dictionaries.
        namespace: (str, optional) The default namespace to use when dictionary or list of dictionaries not provided.
        model: (Any, required) The model to use for embedding
    Returns:
        List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
-    """
+    """  # noqa: E501
     if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
         to_embed, str
     ):
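The updated `embed` docstring says any model with an `encode` function works, so a toy encoder is enough to exercise it. A hedged sketch (FakeEncoder is made up, the keyword names follow the docstring's attribute list, and the exact output formatting is internal to rl_chain):

from langchain.chains.rl_chain.base import embed


class FakeEncoder:
    def encode(self, text):
        return [float(len(text))]  # toy one-dimensional "embedding"


# expected shape: [{"actions": "<joined floats>"}]
print(embed("action text", model=FakeEncoder(), namespace="actions"))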

libs/langchain/langchain/chains/rl_chain/metrics.py

Lines changed: 7 additions & 3 deletions
@@ -1,5 +1,7 @@
-import pandas as pd
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    import pandas as pd
 
 
 class MetricsTracker:
@@ -23,5 +25,7 @@ def on_feedback(self, score: Optional[float]) -> None:
         if self._step > 0 and self._i % self._step == 0:
             self._history.append({"step": self._i, "score": self.score})
 
-    def to_pandas(self) -> pd.DataFrame:
+    def to_pandas(self) -> "pd.DataFrame":
+        import pandas as pd
+
         return pd.DataFrame(self._history)
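After this change pandas is only required if metrics are actually exported. A quick sketch of the deferred behavior (the constructor argument is assumed from the `_step` field, which this hunk does not show):

from langchain.chains.rl_chain.metrics import MetricsTracker

tracker = MetricsTracker(step=1)  # argument name assumed for illustration
tracker.on_feedback(0.5)
df = tracker.to_pandas()  # pandas is imported here, on first use
print(df)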

libs/langchain/langchain/chains/rl_chain/model_repository.py

Lines changed: 11 additions & 7 deletions
@@ -1,11 +1,13 @@
-from pathlib import Path
-import shutil
 import datetime
-import vowpal_wabbit_next as vw
-from typing import Union, Sequence
-import os
 import glob
 import logging
+import os
+import shutil
+from pathlib import Path
+from typing import TYPE_CHECKING, Sequence, Union
+
+if TYPE_CHECKING:
+    import vowpal_wabbit_next as vw
 
 logger = logging.getLogger(__name__)
 
@@ -35,14 +37,16 @@ def get_tag(self) -> str:
     def has_history(self) -> bool:
         return len(glob.glob(str(self.folder / "model-????????-??????.vw"))) > 0
 
-    def save(self, workspace: vw.Workspace) -> None:
+    def save(self, workspace: "vw.Workspace") -> None:
         with open(self.model_path, "wb") as f:
             logger.info(f"storing rl_chain model in: {self.model_path}")
             f.write(workspace.serialize())
         if self.with_history:  # write history
             shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
 
-    def load(self, commandline: Sequence[str]) -> vw.Workspace:
+    def load(self, commandline: Sequence[str]) -> "vw.Workspace":
+        import vowpal_wabbit_next as vw
+
         model_data = None
         if self.model_path.exists():
             with open(self.model_path, "rb") as f:
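A close equivalent to quoting each annotation, not used in this PR, is PEP 563's postponed evaluation, which makes every annotation lazy and lets the guarded name appear unquoted (the `save_sketch` helper is hypothetical; `workspace.serialize()` is the call seen in `save` above):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import vowpal_wabbit_next as vw


def save_sketch(workspace: vw.Workspace) -> bytes:
    # the unquoted annotation is never evaluated at runtime
    return workspace.serialize()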
