Skip to content

Commit

Permalink
Fix the stops_at argument to LLMs
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Apr 7, 2023
1 parent 1645a4e commit 88ec4ff
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 0 additions & 2 deletions outlines/text/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ def __call__(self, prompt, stops_at=None, name=None):
"""
res = super().__call__(prompt)

self.stops_at = stops_at

if name is not None:
res.name = name

Expand Down
9 changes: 5 additions & 4 deletions outlines/text/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class OpenAI(LanguageModel):
"""

def __init__(self, model: str):
def __init__(self, model: str, stops_at=None):
"""Initialize the OpenAI model."""

try:
Expand All @@ -32,13 +32,14 @@ def __init__(self, model: str):
if model not in available_model_names:
raise OSError(f"{model} is not a valid OpenAI model name.")

if stops_at is not None and len(stops_at) > 4:
raise Exception("OpenAI's API does not accept more than 4 stop sequences.")
self.stops_at = stops_at

super().__init__(name=f"OpenAI {model}")
self.model = model

def perform(self, prompt):
if self.stops_at is not None and len(self.stops_at) > 4:
raise Exception("OpenAI's API does not accept more than 4 stop sequences.")

try:
resp = openai.Completion.create(
model=self.model,
Expand Down

0 comments on commit 88ec4ff

Please sign in to comment.