
Commit 44badd0

add dependency requirements to test file
1 parent e276ae2 commit 44badd0

File tree

2 files changed: +16 additions, -16 deletions

libs/langchain/langchain/chains/rl_chain/base.py

Lines changed: 1 addition & 1 deletion

@@ -471,7 +471,7 @@ def _call(

     def save_progress(self) -> None:
         """
-        This function should be called to save the state of the Vowpal Wabbit model.
+        This function should be called to save the state of the learned policy model.
         """
         self.policy.save()
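For context, save_progress persists the policy that the chain trains as it runs. A rough usage sketch follows (not code from this commit; the import paths, fake chat model, and the user/action input names are assumptions inferred from the file paths and tests in this change):

# Hypothetical usage sketch; import paths and keyword names are assumptions.
from langchain.chains.rl_chain import base as rl_chain, pick_best_chain
from langchain.chat_models.fake import FakeListChatModel
from langchain.prompts import PromptTemplate

llm = FakeListChatModel(responses=["100"])  # fake llm, as in the unit tests
prompt = PromptTemplate(
    input_variables=[],
    template="This is a dummy prompt that will be ignored by the fake llm",
)

chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=prompt)
chain.run(
    user=rl_chain.BasedOn("Tom"),                   # context the policy conditions on
    action=rl_chain.ToSelectFrom(["A", "B", "C"]),  # candidates to pick from
)

# Persist the learned policy; the base.py change above only rewords this
# method's docstring, the behaviour is unchanged.
chain.save_progress()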

libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py

Lines changed: 15 additions & 15 deletions
@@ -9,7 +9,7 @@
 encoded_text = "[ e n c o d e d ] "


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def setup():
     _PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
     PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
@@ -18,7 +18,7 @@ def setup():
     return llm, PROMPT


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_multiple_ToSelectFrom_throws():
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@@ -31,7 +31,7 @@ def test_multiple_ToSelectFrom_throws():
     )


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_missing_basedOn_from_throws():
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@@ -40,7 +40,7 @@ def test_missing_basedOn_from_throws():
         chain.run(action=rl_chain.ToSelectFrom(actions))


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_ToSelectFrom_not_a_list_throws():
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@@ -52,7 +52,7 @@ def test_ToSelectFrom_not_a_list_throws():
     )


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_update_with_delayed_score_with_auto_validator_throws():
     llm, PROMPT = setup()
     # this LLM returns a number so that the auto validator will return that
@@ -74,7 +74,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
         chain.update_with_delayed_score(event=selection_metadata, score=100)


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_update_with_delayed_score_force():
     llm, PROMPT = setup()
     # this LLM returns a number so that the auto validator will return that
@@ -98,7 +98,7 @@ def test_update_with_delayed_score_force():
     assert selection_metadata.selected.score == 100.0


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_update_with_delayed_score():
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(
@@ -116,7 +116,7 @@ def test_update_with_delayed_score():
     assert selection_metadata.selected.score == 100.0


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_user_defined_scorer():
     llm, PROMPT = setup()

@@ -138,7 +138,7 @@ def score_response(self, inputs, llm_response: str) -> float:
     assert selection_metadata.selected.score == 200.0


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_default_embeddings():
     llm, PROMPT = setup()
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@@ -172,7 +172,7 @@ def test_default_embeddings():
     assert vw_str == expected


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_default_embeddings_off():
     llm, PROMPT = setup()
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@@ -198,7 +198,7 @@ def test_default_embeddings_off():
     assert vw_str == expected


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_default_embeddings_mixed_w_explicit_user_embeddings():
     llm, PROMPT = setup()
     feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@@ -233,7 +233,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
     assert vw_str == expected


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_default_no_scorer_specified():
     _, PROMPT = setup()
     chain_llm = FakeListChatModel(responses=[100])
@@ -248,7 +248,7 @@ def test_default_no_scorer_specified():
     assert selection_metadata.selected.score == 100.0


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_explicitly_no_scorer():
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(
@@ -264,7 +264,7 @@ def test_explicitly_no_scorer():
     assert selection_metadata.selected.score is None


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_auto_scorer_with_user_defined_llm():
     llm, PROMPT = setup()
     scorer_llm = FakeListChatModel(responses=[300])
@@ -283,7 +283,7 @@ def test_auto_scorer_with_user_defined_llm():
     assert selection_metadata.selected.score == 300.0


-@pytest.mark.requires("vowpal_wabbit_next")
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_calling_chain_w_reserved_inputs_throws():
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
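The net effect of the test-file change is that every test in this module is skipped unless both vowpal_wabbit_next and sentence_transformers are importable. A minimal sketch of how a conftest hook can honor such a requires marker (an illustrative assumption, not langchain's actual conftest implementation):

# conftest.py -- hypothetical sketch of resolving @pytest.mark.requires(...)
import importlib.util

import pytest


def pytest_runtest_setup(item: pytest.Item) -> None:
    # Skip the test if any package named in a "requires" marker is not installed.
    for marker in item.iter_markers(name="requires"):
        for package in marker.args:
            if importlib.util.find_spec(package) is None:
                pytest.skip(f"requires the '{package}' package")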
