"PLEASE RESPOND ONLY WITH A SIGNLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
220
-
You must respond with your ranking by providing a single float within the range [0, 1], 0 being very bad response and 1 being very good response."
226
+
"PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n\
227
+
You are a strict judge that is called on to rank a response based on \
228
+
given criteria. You must respond with your ranking by providing a \
229
+
single float within the range [0, 1], 0 being very bad \
230
+
response and 1 being very good response."
221
231
)
222
232
223
233
@staticmethod
224
234
def get_default_prompt() -> ChatPromptTemplate:
225
-
human_template = 'Given this based_on "{rl_chain_selected_based_on}" as the most important attribute, rank how good or bad this text is: "{llm_response}".'
235
+
human_template = 'Given this based_on "{rl_chain_selected_based_on}" \
236
+
as the most important attribute, rank how good or bad this text is: \
f"The llm did not manage to rank the response as expected, there is always the option to try again or tweak the reward prompt. Error: {e}"
272
+
f"The auto selection scorer did not manage to score the response, \
273
+
there is always the option to try again or tweak the reward prompt.\
274
+
Error: {e}"
261
275
)
262
276
263
277
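The reward prompt in this diff asks the judge for a single float in [0, 1], and the error message above is raised when scoring fails. A minimal sketch of that parse-and-validate step follows. The helper name `parse_score` and the use of `RuntimeError` are assumptions for illustration, not code from this PR.

```python
def parse_score(raw: str) -> float:
    """Parse a judge reply that is expected to be a single float in [0, 1].

    Hypothetical helper: the name and the exception type are illustrative only.
    """
    try:
        score = float(raw.strip())
    except ValueError as e:
        # The reply contained extra text or was not a number at all.
        raise RuntimeError(
            f"Could not parse a score from {raw!r}. Error: {e}"
        ) from e
    # NaN fails this comparison as well, so it is rejected here.
    if not 0.0 <= score <= 1.0:
        raise RuntimeError(f"Score {score} is outside the range [0, 1].")
    return score
```

Rejecting out-of-range values up front keeps a malformed reply from being fed to the learner as a reward.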
264
278
class RLChain(Chain):
265
279
"""
266
-
RLChain class that utilizes the Vowpal Wabbit (VW) model for personalization.
280
+
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
267
281
268
282
Attributes:
269
-
model_loading (bool, optional): If set to True, the chain will attempt to load an existing VW model from the latest checkpoint file in the {model_save_dir} directory (current directory if none specified). If set to False, it will start training from scratch, potentially overwriting existing files. Defaults to True.
270
-
large_action_spaces (bool, optional): If set to True and vw_cmd has not been specified in the constructor, it will enable large action spaces
271
-
vw_cmd (List[str], optional): Advanced users can set the VW command line to whatever they want, as long as it is compatible with the Type that is specified (Type Enum)
272
-
model_save_dir (str, optional): The directory to save the VW model to. Defaults to the current directory.
273
-
selection_scorer (SelectionScorer): If set, the chain will check the response using the provided selection_scorer and the VW model will be updated with the result. Defaults to None.
283
+
- llm_chain (Chain): Represents the underlying Language Model chain.
284
+
- prompt (BasePromptTemplate): The template for the base prompt.
285
+
- selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
286
+
- policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
287
+
- auto_embed (bool): Determines if embedding should be automatic. Default is True.
288
+
- metrics (Optional[MetricsTracker]): Tracker for metrics, can be set to None.
289
+
290
+
Initialization Attributes:
291
+
- feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs.
292
+
- model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
293
+
- reset_model (bool): If set to True, the model starts training from scratch. Default is False.
294
+
- vw_cmd (List[str], optional): Command line arguments for the VW model.
295
+
- policy (VwPolicy): Policy used by the chain.
296
+
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
297
+
- metrics_step (int): Step for the metrics tracker. Default is -1.
274
298
275
299
Notes:
276
-
The class creates a VW model instance using the provided arguments. Before the chain object is destroyed the save_progress() function can be called. If it is called, the learned VW model is saved to a file in the current directory named `model-<checkpoint>.vw`. Checkpoints start at 1 and increment monotonically.
277
-
When making predictions, VW is first called to choose action(s) which are then passed into the prompt with the key `{actions}`. After action selection, the LLM (Language Model) is called with the prompt populated by the chosen action(s), and the response is returned.
278
-
"""
300
+
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
301
+
"""# noqa: E501
279
302
280
303
llm_chain: Chain
281
304
@@ -303,7 +326,9 @@ def __init__(
303
326
super().__init__(*args, **kwargs)
304
327
if self.selection_scorer is None:
305
328
logger.warning(
306
-
"No response validator provided, which means that no reinforcement learning will be done in the RL chain unless update_with_delayed_score is called."
329
+
"No selection scorer provided, which means that no \
330
+
reinforcement learning will be done in the RL chain \
f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward."
371
+
f"The rl chain does not accept '{self.selected_input_key}' \
372
+
or '{self.selected_based_on_input_key}' as input keys, \
373
+
they are reserved for internal use during auto reward."
Set whether the chain should auto embed the inputs or not. If set to False, the inputs will not be embedded and the user will need to embed the inputs themselves before calling run.
391
-
392
-
Args:
393
-
auto_embed (bool): Whether the chain should auto embed the inputs or not.
417
+
Sets whether the chain should auto embed the inputs or not.
394
418
"""
395
419
self.auto_embed = auto_embed
396
420
@@ -435,7 +459,8 @@ def _call(
435
459
)
436
460
except Exception as e:
437
461
logger.info(
438
-
f"The LLM was not able to rank and the chain was not able to adjust to this response, error: {e}"
462
+
f"The selection scorer was not able to score, \
463
+
and the chain was not able to adjust to this response, error: {e}"
This function should be called whenever there is a need to save the progress of the VW (Vowpal Wabbit) model within the chain. It saves the current state of the VW model to a file.
450
-
451
-
File Naming Convention:
452
-
The file will be named using the pattern `model-<checkpoint>.vw`, where `<checkpoint>` is a monotonically increasing number. The numbering starts from 1, and increments by 1 for each subsequent save. If there are already saved checkpoints, the number used for `<checkpoint>` will be the next in the sequence.
453
-
454
-
Example:
455
-
If there are already two saved checkpoints, `model-1.vw` and `model-2.vw`, the next time this function is called, it will save the model as `model-3.vw`.
456
-
457
-
Note:
458
-
Be cautious when deleting or renaming checkpoint files manually, as this could cause the function to reuse checkpoint numbers.
474
+
This function should be called to save the state of the learned policy model.
459
475
"""
460
476
self.policy.save()
461
477
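The docstring removed in this hunk described a `model-<checkpoint>.vw` naming convention in which the checkpoint number starts at 1 and each save uses the next number in the sequence. A minimal sketch of that convention follows. The helper `next_checkpoint_path` is hypothetical, and the behavior of `self.policy.save()` itself is not shown in this diff.

```python
import re
from pathlib import Path

# Matches file names such as "model-3.vw" and captures the checkpoint number.
_CHECKPOINT_RE = re.compile(r"^model-(\d+)\.vw$")


def next_checkpoint_path(model_dir: Path) -> Path:
    """Return the path for the next checkpoint in model_dir.

    Hypothetical helper illustrating the documented convention only.
    """
    numbers = [
        int(match.group(1))
        for path in model_dir.iterdir()
        if (match := _CHECKPOINT_RE.match(path.name))
    ]
    # With no existing checkpoints, numbering starts at 1.
    return model_dir / f"model-{max(numbers, default=0) + 1}.vw"
```

Because the next number is derived from the files present, deleting or renaming the highest checkpoint would cause its number to be reused, which is the caveat the removed docstring mentioned.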
@@ -490,7 +506,8 @@ def embed_string_type(
490
506
491
507
if namespace is None:
492
508
raise ValueError(
493
-
"The default namespace must be provided when embedding a string or _Embed object."
509
+
"The default namespace must be \
510
+
provided when embedding a string or _Embed object."
Embeds the actions or context using the SentenceTransformer model
550
+
Embeds the actions or context using the SentenceTransformer model (or a model that has an `encode` function)
534
551
535
552
Attributes:
536
553
to_embed: (Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]], required) The text to be embedded, either a string, a list of strings or a dictionary or a list of dictionaries.
537
554
namespace: (str, optional) The default namespace to use when dictionary or list of dictionaries not provided.
538
555
model: (Any, required) The model to use for embedding
539
556
Returns:
540
557
List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
541
-
"""
558
+
"""# noqa: E501
542
559
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
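The embedding docstring above says the function accepts any model with an `encode` function and returns a list of dictionaries keyed by namespace. A minimal sketch of that contract for the plain list-of-strings case follows. `Encoder`, `ToyEncoder`, and `embed_strings` are hypothetical names, and joining the vector components with spaces is an assumption about how the embedded string is formatted, not the PR's implementation.

```python
from typing import Dict, List, Optional, Protocol, Sequence


class Encoder(Protocol):
    """Anything exposing an encode(text) method returning a vector."""

    def encode(self, text: str) -> Sequence[float]: ...


class ToyEncoder:
    """Stand-in model: encodes text as [length, vowel count]."""

    def encode(self, text: str) -> List[float]:
        vowels = sum(ch in "aeiou" for ch in text.lower())
        return [float(len(text)), float(vowels)]


def embed_strings(
    to_embed: List[str], namespace: Optional[str], model: Encoder
) -> List[Dict[str, str]]:
    """Embed each string and key the result by the default namespace."""
    if namespace is None:
        raise ValueError(
            "The default namespace must be provided when embedding a string."
        )
    # Assumption: the vector is serialized as space-separated numbers.
    return [
        {namespace: " ".join(str(v) for v in model.encode(text))}
        for text in to_embed
    ]
```

Typing the model as a protocol rather than a concrete class reflects the docstring's wording that only an `encode` function is required.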