Skip to content

[lint] standardize isort and black in pyproject.toml #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .isort.cfg

This file was deleted.

2 changes: 1 addition & 1 deletion openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
organization = os.environ.get("OPENAI_ORGANIZATION")
api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
api_type = os.environ.get("OPENAI_API_TYPE", "open_ai")
api_version = '2021-11-01-preview' if api_type == "azure" else None
api_version = "2021-11-01-preview" if api_type == "azure" else None
verify_ssl_certs = True # No effect. Certificates are always verified.
proxy = None
app_info = None
Expand Down
22 changes: 18 additions & 4 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from email import header
import json
import platform
import threading
import warnings
from email import header
from json import JSONDecodeError
from typing import Dict, Iterator, Optional, Tuple, Union
from urllib.parse import urlencode, urlsplit, urlunsplit
Expand Down Expand Up @@ -71,10 +71,21 @@ def parse_stream(rbody):


class APIRequestor:
def __init__(self, key=None, api_base=None, api_type=None, api_version=None, organization=None):
def __init__(
self,
key=None,
api_base=None,
api_type=None,
api_version=None,
organization=None,
):
self.api_base = api_base or openai.api_base
self.api_key = key or util.default_api_key()
self.api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)
self.api_type = (
ApiType.from_str(api_type)
if api_type
else ApiType.from_str(openai.api_type)
)
self.api_version = api_version or openai.api_version
self.organization = organization or openai.organization

Expand Down Expand Up @@ -324,7 +335,10 @@ def _interpret_response_line(
) -> OpenAIResponse:
if rcode == 503:
raise error.ServiceUnavailableError(
"The server is overloaded or not ready yet.", rbody, rcode, headers=rheaders
"The server is overloaded or not ready yet.",
rbody,
rcode,
headers=rheaders,
)
try:
if hasattr(rbody, "decode"):
Expand Down
18 changes: 12 additions & 6 deletions openai/api_resources/abstract/api_resource.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from urllib.parse import quote_plus

from openai import api_requestor, error, util
import openai
from openai import api_requestor, error, util
from openai.openai_object import OpenAIObject
from openai.util import ApiType


class APIResource(OpenAIObject):
api_prefix = ""
azure_api_prefix = 'openai/deployments'
azure_api_prefix = "openai/deployments"

@classmethod
def retrieve(cls, id, api_key=None, request_id=None, **params):
Expand Down Expand Up @@ -49,22 +49,28 @@ def instance_url(self, operation=None):

if self.typed_api_type == ApiType.AZURE:
if not api_version:
raise error.InvalidRequestError("An API version is required for the Azure API type.")
raise error.InvalidRequestError(
"An API version is required for the Azure API type."
)
if not operation:
raise error.InvalidRequestError(
"The request needs an operation (eg: 'search') for the Azure OpenAI API type."
)
extn = quote_plus(id)
return "/%s/%s/%s?api-version=%s" % (self.azure_api_prefix, extn, operation, api_version)
return "/%s/%s/%s?api-version=%s" % (
self.azure_api_prefix,
extn,
operation,
api_version,
)

elif self.typed_api_type == ApiType.OPEN_AI:
base = self.class_url()
extn = quote_plus(id)
return "%s/%s" % (base, extn)

else:
raise error.InvalidAPIType('Unsupported API type %s' % self.api_type)

raise error.InvalidAPIType("Unsupported API type %s" % self.api_type)

# The `method_` and `url_` arguments are suffixed with an underscore to
# avoid conflicting with actual request parameters in `params`.
Expand Down
49 changes: 36 additions & 13 deletions openai/api_resources/abstract/engine_api_resource.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydoc import apropos
import time
from pydoc import apropos
from typing import Optional
from urllib.parse import quote_plus

Expand All @@ -15,28 +15,44 @@
class EngineAPIResource(APIResource):
engine_required = True
plain_old_data = False
azure_api_prefix = 'openai/deployments'
azure_api_prefix = "openai/deployments"

def __init__(self, engine: Optional[str] = None, **kwargs):
super().__init__(engine=engine, **kwargs)

@classmethod
def class_url(cls, engine: Optional[str] = None, api_type : Optional[str] = None, api_version: Optional[str] = None):
def class_url(
cls,
engine: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
):
# Namespaces are separated in object names with periods (.) and in URLs
# with forward slashes (/), so replace the former with the latter.
base = cls.OBJECT_NAME.replace(".", "/") # type: ignore
typed_api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)
typed_api_type = (
ApiType.from_str(api_type)
if api_type
else ApiType.from_str(openai.api_type)
)
api_version = api_version or openai.api_version

if typed_api_type == ApiType.AZURE:
if not api_version:
raise error.InvalidRequestError("An API version is required for the Azure API type.")
raise error.InvalidRequestError(
"An API version is required for the Azure API type."
)
if engine is None:
raise error.InvalidRequestError(
"You must provide the deployment name in the 'engine' parameter to access the Azure OpenAI service"
)
extn = quote_plus(engine)
return "/%s/%s/%ss?api-version=%s" % (cls.azure_api_prefix, extn, base, api_version)
return "/%s/%s/%ss?api-version=%s" % (
cls.azure_api_prefix,
extn,
base,
api_version,
)

elif typed_api_type == ApiType.OPEN_AI:
if engine is None:
Expand All @@ -46,8 +62,7 @@ def class_url(cls, engine: Optional[str] = None, api_type : Optional[str] = None
return "/engines/%s/%ss" % (extn, base)

else:
raise error.InvalidAPIType('Unsupported API type %s' % api_type)

raise error.InvalidAPIType("Unsupported API type %s" % api_type)

@classmethod
def create(
Expand Down Expand Up @@ -133,23 +148,31 @@ def instance_url(self):
"id",
)

params_connector = '?'
params_connector = "?"
if self.typed_api_type == ApiType.AZURE:
api_version = self.api_version or openai.api_version
if not api_version:
raise error.InvalidRequestError("An API version is required for the Azure API type.")
raise error.InvalidRequestError(
"An API version is required for the Azure API type."
)
extn = quote_plus(id)
base = self.OBJECT_NAME.replace(".", "/")
url = "/%s/%s/%ss/%s?api-version=%s" % (self.azure_api_prefix, self.engine, base, extn, api_version)
params_connector = '&'
url = "/%s/%s/%ss/%s?api-version=%s" % (
self.azure_api_prefix,
self.engine,
base,
extn,
api_version,
)
params_connector = "&"

elif self.typed_api_type == ApiType.OPEN_AI:
base = self.class_url(self.engine, self.api_type, self.api_version)
extn = quote_plus(id)
url = "%s/%s" % (base, extn)

else:
raise error.InvalidAPIType('Unsupported API type %s' % self.api_type)
raise error.InvalidAPIType("Unsupported API type %s" % self.api_type)

timeout = self.get("timeout")
if timeout is not None:
Expand Down
2 changes: 1 addition & 1 deletion openai/api_resources/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def search(self, **params):
elif self.typed_api_type == ApiType.OPEN_AI:
return self.request("post", self.instance_url() + "/search", params)
else:
raise InvalidAPIType('Unsupported API type %s' % self.api_type)
raise InvalidAPIType("Unsupported API type %s" % self.api_type)

def embeddings(self, **params):
warnings.warn(
Expand Down
2 changes: 1 addition & 1 deletion openai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import requests

import openai
import openai.wandb_logger
from openai.upload_progress import BufferReader
from openai.validators import (
apply_necessary_remediation,
Expand All @@ -19,7 +20,6 @@
write_out_file,
write_out_search_file,
)
import openai.wandb_logger


class bcolors:
Expand Down
1 change: 1 addition & 0 deletions openai/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class RateLimitError(OpenAIError):
class ServiceUnavailableError(OpenAIError):
pass


class InvalidAPIType(OpenAIError):
pass

Expand Down
6 changes: 5 additions & 1 deletion openai/openai_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,11 @@ def openai_id(self):

@property
def typed_api_type(self):
return ApiType.from_str(self.api_type) if self.api_type else ApiType.from_str(openai.api_type)
return (
ApiType.from_str(self.api_type)
if self.api_type
else ApiType.from_str(openai.api_type)
)

# This class overrides __setitem__ to throw exceptions on inputs that it
# doesn't like. This can cause problems when we try to copy an object
Expand Down
20 changes: 14 additions & 6 deletions openai/tests/test_api_requestor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import json

import pytest
import requests
from pytest_mock import MockerFixture

from openai import Model
from openai.api_requestor import APIRequestor


@pytest.mark.requestor
def test_requestor_sets_request_id(mocker: MockerFixture) -> None:
# Fake out 'requests' and confirm that the X-Request-Id header is set.
Expand All @@ -27,24 +29,30 @@ def fake_request(self, *args, **kwargs):
got_request_id = got_headers.get("X-Request-Id")
assert got_request_id == fake_request_id


@pytest.mark.requestor
def test_requestor_open_ai_headers() -> None:
api_requestor = APIRequestor(key="test_key", api_type="open_ai")
headers = {"Test_Header": "Unit_Test_Header"}
headers = api_requestor.request_headers(method="get", extra=headers, request_id="test_id")
headers = api_requestor.request_headers(
method="get", extra=headers, request_id="test_id"
)
print(headers)
assert "Test_Header"in headers
assert "Test_Header" in headers
assert headers["Test_Header"] == "Unit_Test_Header"
assert "Authorization"in headers
assert "Authorization" in headers
assert headers["Authorization"] == "Bearer test_key"


@pytest.mark.requestor
def test_requestor_azure_headers() -> None:
api_requestor = APIRequestor(key="test_key", api_type="azure")
headers = {"Test_Header": "Unit_Test_Header"}
headers = api_requestor.request_headers(method="get", extra=headers, request_id="test_id")
headers = api_requestor.request_headers(
method="get", extra=headers, request_id="test_id"
)
print(headers)
assert "Test_Header"in headers
assert "Test_Header" in headers
assert headers["Test_Header"] == "Unit_Test_Header"
assert "api-key"in headers
assert "api-key" in headers
assert headers["api-key"] == "test_key"
Loading