From 946aa8fa4a49f8331a1ad73ce2b063e26b1980d7 Mon Sep 17 00:00:00 2001 From: hallacy Date: Wed, 2 Feb 2022 13:59:23 -0800 Subject: [PATCH] [headers] Allow user provided headers in completion (#116) (#71) --- openai/api_requestor.py | 35 +++++++++++++++---- .../abstract/engine_api_resource.py | 8 ++++- openai/openai_object.py | 7 +++- openai/tests/test_endpoints.py | 6 ++++ 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/openai/api_requestor.py b/openai/api_requestor.py index cac62f5544..dd618c6e5d 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -100,8 +100,8 @@ def request( result = self.request_raw( method.lower(), url, - params, - headers, + params=params, + supplied_headers=headers, files=files, stream=stream, request_id=request_id, @@ -212,18 +212,41 @@ def request_headers( return headers + def _validate_headers( + self, supplied_headers: Optional[Dict[str, str]] + ) -> Dict[str, str]: + headers: Dict[str, str] = {} + if supplied_headers is None: + return headers + + if not isinstance(supplied_headers, dict): + raise TypeError("Headers must be a dictionary") + + for k, v in supplied_headers.items(): + if not isinstance(k, str): + raise TypeError("Header keys must be strings") + if not isinstance(v, str): + raise TypeError("Header values must be strings") + headers[k] = v + + # NOTE: It is possible to do more validation of the headers, but a request could always + # be made to the API manually with invalid headers, so we need to handle them server side. + + return headers + def request_raw( self, method, url, + *, params=None, - supplied_headers=None, + supplied_headers: Dict[str, str] = None, files=None, - stream=False, + stream: bool = False, request_id: Optional[str] = None, ) -> requests.Response: abs_url = "%s%s" % (self.api_base, url) - headers = {} + headers = self._validate_headers(supplied_headers) data = None if method == "get" or method == "delete": @@ -246,8 +269,6 @@ def request_raw( ) headers = self.request_headers(method, headers, request_id) - if supplied_headers is not None: - headers.update(supplied_headers) util.log_info("Request to OpenAI API", method=method, path=abs_url) util.log_debug("Post details", data=data, api_version=self.api_version) diff --git a/openai/api_resources/abstract/engine_api_resource.py b/openai/api_resources/abstract/engine_api_resource.py index 84904e1f80..e3592c8599 100644 --- a/openai/api_resources/abstract/engine_api_resource.py +++ b/openai/api_resources/abstract/engine_api_resource.py @@ -63,6 +63,7 @@ def create( engine = params.pop("engine", None) timeout = params.pop("timeout", None) stream = params.get("stream", False) + headers = params.pop("headers", None) if engine is None and cls.engine_required: raise error.InvalidRequestError( "Must provide an 'engine' parameter to create a %s" % cls, "engine" @@ -87,7 +88,12 @@ def create( ) url = cls.class_url(engine, api_type, api_version) response, _, api_key = requestor.request( - "post", url, params, stream=stream, request_id=request_id + "post", + url, + params=params, + headers=headers, + stream=stream, + request_id=request_id, ) if stream: diff --git a/openai/openai_object.py b/openai/openai_object.py index 3dbd2a8ed7..f785c89484 100644 --- a/openai/openai_object.py +++ b/openai/openai_object.py @@ -176,7 +176,12 @@ def request( organization=self.organization, ) response, stream, api_key = requestor.request( - method, url, params, stream=stream, headers=headers, request_id=request_id + method, + url, + params=params, + stream=stream, + headers=headers, + request_id=request_id, ) if stream: diff --git a/openai/tests/test_endpoints.py b/openai/tests/test_endpoints.py index 6ef4d2d373..80039aa995 100644 --- a/openai/tests/test_endpoints.py +++ b/openai/tests/test_endpoints.py @@ -28,3 +28,9 @@ def test_completions_multiple_prompts(): prompt=["This was a test", "This was another test"], n=5, engine="ada" ) assert len(result.choices) == 10 + + +def test_completions_model(): + result = openai.Completion.create(prompt="This was a test", n=5, model="ada") + assert len(result.choices) == 5 + assert result.model.startswith("ada:")