Skip to content

Commit

Permalink
[headers] Allow user provided headers in completion (#116) (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
hallacy authored Feb 2, 2022
1 parent 62b51ca commit 946aa8f
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 9 deletions.
35 changes: 28 additions & 7 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def request(
result = self.request_raw(
method.lower(),
url,
params,
headers,
params=params,
supplied_headers=headers,
files=files,
stream=stream,
request_id=request_id,
Expand Down Expand Up @@ -212,18 +212,41 @@ def request_headers(

return headers

def _validate_headers(
self, supplied_headers: Optional[Dict[str, str]]
) -> Dict[str, str]:
headers: Dict[str, str] = {}
if supplied_headers is None:
return headers

if not isinstance(supplied_headers, dict):
raise TypeError("Headers must be a dictionary")

for k, v in supplied_headers.items():
if not isinstance(k, str):
raise TypeError("Header keys must be strings")
if not isinstance(v, str):
raise TypeError("Header values must be strings")
headers[k] = v

# NOTE: It is possible to do more validation of the headers, but a request could always
# be made to the API manually with invalid headers, so we need to handle them server side.

return headers

def request_raw(
self,
method,
url,
*,
params=None,
supplied_headers=None,
supplied_headers: Dict[str, str] = None,
files=None,
stream=False,
stream: bool = False,
request_id: Optional[str] = None,
) -> requests.Response:
abs_url = "%s%s" % (self.api_base, url)
headers = {}
headers = self._validate_headers(supplied_headers)

data = None
if method == "get" or method == "delete":
Expand All @@ -246,8 +269,6 @@ def request_raw(
)

headers = self.request_headers(method, headers, request_id)
if supplied_headers is not None:
headers.update(supplied_headers)

util.log_info("Request to OpenAI API", method=method, path=abs_url)
util.log_debug("Post details", data=data, api_version=self.api_version)
Expand Down
8 changes: 7 additions & 1 deletion openai/api_resources/abstract/engine_api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def create(
engine = params.pop("engine", None)
timeout = params.pop("timeout", None)
stream = params.get("stream", False)
headers = params.pop("headers", None)
if engine is None and cls.engine_required:
raise error.InvalidRequestError(
"Must provide an 'engine' parameter to create a %s" % cls, "engine"
Expand All @@ -87,7 +88,12 @@ def create(
)
url = cls.class_url(engine, api_type, api_version)
response, _, api_key = requestor.request(
"post", url, params, stream=stream, request_id=request_id
"post",
url,
params=params,
headers=headers,
stream=stream,
request_id=request_id,
)

if stream:
Expand Down
7 changes: 6 additions & 1 deletion openai/openai_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,12 @@ def request(
organization=self.organization,
)
response, stream, api_key = requestor.request(
method, url, params, stream=stream, headers=headers, request_id=request_id
method,
url,
params=params,
stream=stream,
headers=headers,
request_id=request_id,
)

if stream:
Expand Down
6 changes: 6 additions & 0 deletions openai/tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,9 @@ def test_completions_multiple_prompts():
prompt=["This was a test", "This was another test"], n=5, engine="ada"
)
assert len(result.choices) == 10


def test_completions_model():
result = openai.Completion.create(prompt="This was a test", n=5, model="ada")
assert len(result.choices) == 5
assert result.model.startswith("ada:")

0 comments on commit 946aa8f

Please sign in to comment.