Skip to content

Commit ec5267b

Browse files
committed
Merge branch 'main' into support-other-models-and-all-fields-pgvector
2 parents c2dfd58 + ab0b1fb commit ec5267b

File tree

5 files changed

+141
-185
lines changed

5 files changed

+141
-185
lines changed

dsp/modules/hf_client.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,16 @@ def send_hftgi_request_v00(arg, **kwargs):
116116
class HFClientVLLM(HFModel):
117117
def __init__(self, model, port, url="http://localhost", **kwargs):
118118
super().__init__(model=model, is_client=True)
119-
self.url = f"{url}:{port}"
119+
120+
if isinstance(url, list):
121+
self.urls = url
122+
123+
elif isinstance(url, str):
124+
self.urls = [f'{url}:{port}']
125+
126+
else:
127+
raise ValueError(f"The url provided to `HFClientVLLM` is neither a string nor a list of strings. It is of type {type(url)}.")
128+
120129
self.headers = {"Content-Type": "application/json"}
121130

122131
def _generate(self, prompt, **kwargs):
@@ -128,9 +137,13 @@ def _generate(self, prompt, **kwargs):
128137
"max_tokens": kwargs["max_tokens"],
129138
"temperature": kwargs["temperature"],
130139
}
140+
141+
# Round robin the urls.
142+
url = self.urls.pop(0)
143+
self.urls.append(url)
131144

132145
response = send_hfvllm_request_v00(
133-
f"{self.url}/v1/completions",
146+
f"{url}/v1/completions",
134147
json=payload,
135148
headers=self.headers,
136149
)

dsp/modules/lm.py

+14-21
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def print_red(self, text: str, end: str = "\n"):
3333

3434
def inspect_history(self, n: int = 1, skip: int = 0):
3535
"""Prints the last n prompts and their completions.
36-
TODO: print the valid choice that contains filled output field instead of the first
36+
37+
TODO: print the valid choice that contains filled output field instead of the first.
3738
"""
3839
provider: str = self.provider
3940

@@ -45,23 +46,15 @@ def inspect_history(self, n: int = 1, skip: int = 0):
4546
prompt = x["prompt"]
4647

4748
if prompt != last_prompt:
48-
49-
if provider == "clarifai" or provider == "google" or provider == "claude":
50-
printed.append(
51-
(
52-
prompt,
53-
x['response'],
54-
),
55-
)
49+
if provider == "clarifai" or provider == "google":
50+
printed.append((prompt, x["response"]))
51+
elif provider == "anthropic":
52+
blocks = [{"text": block.text} for block in x["response"].content if block.type == "text"]
53+
printed.append((prompt, blocks))
54+
elif provider == "cohere":
55+
printed.append((prompt, x["response"].generations))
5656
else:
57-
printed.append(
58-
(
59-
prompt,
60-
x["response"].generations
61-
if provider == "cohere"
62-
else x["response"]["choices"],
63-
),
64-
)
57+
printed.append((prompt, x["response"]["choices"]))
6558

6659
last_prompt = prompt
6760

@@ -79,9 +72,9 @@ def inspect_history(self, n: int = 1, skip: int = 0):
7972
if provider == "cohere":
8073
text = choices[0].text
8174
elif provider == "openai" or provider == "ollama":
82-
text = ' ' + self._get_choice_text(choices[0]).strip()
83-
elif provider == "clarifai" or provider == "claude" :
84-
text=choices
75+
text = " " + self._get_choice_text(choices[0]).strip()
76+
elif provider == "clarifai":
77+
text = choices
8578
elif provider == "google":
8679
text = choices[0].parts[0].text
8780
else:
@@ -99,6 +92,6 @@ def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
9992
def copy(self, **kwargs):
10093
"""Returns a copy of the language model with the same parameters."""
10194
kwargs = {**self.kwargs, **kwargs}
102-
model = kwargs.pop('model')
95+
model = kwargs.pop("model")
10396

10497
return self.__class__(model=model, **kwargs)

0 commit comments

Comments (0)