
Update provider parameters, check for valid provider #2594

Merged · 2 commits · Jan 24, 2025
Update provider parameters, check for valid provider
Fix reading model list in GeminiPro
Fix content-type check in OpenaiAPI
hlohaus committed Jan 24, 2025
commit fd5fa8a4ebaf80084894141a1164b2da8f36d73d
6 changes: 4 additions & 2 deletions g4f/Provider/needs_auth/GeminiPro.py
@@ -23,6 +23,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):

    working = True
    supports_message_history = True
+   supports_system_message = True
    needs_auth = True

    default_model = "gemini-1.5-pro"
@@ -39,7 +40,8 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
    def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
        if not cls.models:
            try:
-               response = requests.get(f"{api_base}/models?key={api_key}")
+               url = f"{cls.api_base if not api_base else api_base}/models"
+               response = requests.get(url, params={"key": api_key})
                raise_for_status(response)
                data = response.json()
                cls.models = [
@@ -50,7 +52,7 @@ def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
                cls.models.sort()

Review comment: Sorting the models in place can have side effects if the list is used elsewhere. It may be better to return a sorted copy using sorted(cls.models).

            except Exception as e:
                debug.log(e)
-               cls.models = cls.fallback_models
+               return cls.fallback_models
        return cls.models

    @classmethod
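Aside (not part of the committed change): a minimal, self-contained sketch of the reviewer's sorted-copy suggestion applied to this classmethod. The base URL, the fallback list, and the list-comprehension filter are placeholders because those parts of the file are collapsed in the diff, and requests' built-in raise_for_status() stands in for the project's helper.

import requests

class GeminiProSketch:
    # Placeholder values; the real provider defines its own api_base and fallback_models.
    api_base = "https://example.googleapis.com/v1beta"
    fallback_models = ["gemini-1.5-pro", "gemini-1.5-flash"]
    models: list[str] = []

    @classmethod
    def get_models(cls, api_key: str = None, api_base: str = None) -> list[str]:
        if not cls.models:
            try:
                url = f"{cls.api_base if not api_base else api_base}/models"
                response = requests.get(url, params={"key": api_key})
                response.raise_for_status()
                data = response.json()
                # sorted(...) builds a new list, so no shared list is mutated in place.
                cls.models = sorted(
                    model["name"].split("/")[-1] for model in data.get("models", [])
                )
            except Exception:
                # Return the static fallback without caching it, so a later call retries.
                return cls.fallback_models
        return cls.models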
7 changes: 4 additions & 3 deletions g4f/Provider/needs_auth/OpenaiAPI.py
@@ -108,7 +108,8 @@ async def create_async_generator(
        if api_endpoint is None:
            api_endpoint = f"{api_base.rstrip('/')}/chat/completions"
        async with session.post(api_endpoint, json=data) as response:
-           if response.headers.get("content-type", None if stream else "application/json") == "application/json":
+           content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json")
+           if content_type.startswith("application/json"):
                data = await response.json()
                cls.raise_error(data)
                await raise_for_status(response)

Review comment: Ensure that raise_for_status is called after handling the JSON response and errors to avoid potential misinterpretation of response status.

@@ -122,7 +123,7 @@ async def create_async_generator(
                if "finish_reason" in choice and choice["finish_reason"] is not None:
                    yield FinishReason(choice["finish_reason"])
                    return
-           elif response.headers.get("content-type", "text/event-stream" if stream else None) == "text/event-stream":
+           elif content_type.startswith("text/event-stream"):
                await raise_for_status(response)
                first = True
                async for line in response.iter_lines():
@@ -147,7 +148,7 @@ async def create_async_generator(
                        break
            else:
                await raise_for_status(response)
-               raise ResponseError(f"Not supported content-type: {response.headers.get('content-type')}")
+               raise ResponseError(f"Not supported content-type: {content_type}")

    @classmethod
    def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
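Aside (illustrative only, not the provider's actual code): the reworked branching resolves the content type once and switches from exact equality to startswith, so values such as "application/json; charset=utf-8" are no longer rejected. A stand-alone sketch of that dispatch, with the ordering the review comment above is concerned about noted in the comments:

def resolve_content_type(headers: dict, stream: bool) -> str:
    # Mirrors the diff: a missing header defaults to "text/event-stream" when a
    # streamed response was requested, otherwise to "application/json".
    return headers.get("content-type", "text/event-stream" if stream else "application/json")

def dispatch(headers: dict, stream: bool) -> str:
    content_type = resolve_content_type(headers, stream)
    if content_type.startswith("application/json"):
        return "json"   # parse the body and raise API errors, then check the HTTP status
    if content_type.startswith("text/event-stream"):
        return "sse"    # check the HTTP status, then iterate over the event stream
    raise ValueError(f"Not supported content-type: {content_type}")

# "application/json; charset=utf-8" failed the old equality check but passes startswith():
assert dispatch({"content-type": "application/json; charset=utf-8"}, stream=False) == "json"
# A missing header falls back to the stream default:
assert dispatch({}, stream=True) == "sse"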
2 changes: 1 addition & 1 deletion g4f/gui/client/static/js/chat.v1.js
@@ -839,7 +839,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
    await api("conversation", {
        id: message_id,
        conversation_id: window.conversation_id,
-       conversation: conversation.data && provider in conversation.data ? conversation.data[provider] : null,
+       conversation: provider && conversation.data && provider in conversation.data ? conversation.data[provider] : null,
        model: model,
        web_search: switchInput.checked,
        provider: provider,
3 changes: 1 addition & 2 deletions g4f/gui/server/api.py
@@ -62,7 +62,7 @@ def get_providers() -> dict[str, str]:
        "name": provider.__name__,
        "label": provider.label if hasattr(provider, "label") else provider.__name__,
        "parent": getattr(provider, "parent", None),
-       "image": getattr(provider, "image_models", None) is not None,
+       "image": bool(getattr(provider, "image_models", False)),
"vision": getattr(provider, "default_vision_model", None) is not None,

Review comment: Similar to the previous line, you might want to use getattr(provider, "default_vision_model", None) directly instead of checking if it is not None. This can simplify the expression.

"auth": provider.needs_auth,
"login_url": getattr(provider, "login_url", None),
@@ -157,7 +157,6 @@ def decorated_log(text: str):
        **(provider_handler.get_parameters(as_json=True) if hasattr(provider_handler, "get_parameters") else {}),
        "model": model,
        "messages": kwargs.get("messages"),
-       "web_search": kwargs.get("web_search")
    }
    if isinstance(kwargs.get("conversation"), JsonConversation):
        params["conversation"] = kwargs.get("conversation").get_dict()
3 changes: 3 additions & 0 deletions g4f/providers/base_provider.py
@@ -34,6 +34,7 @@
    "api_key", "api_base", "seed", "width", "height",
    "proof_token", "max_retries", "web_search",
    "guidance_scale", "num_inference_steps", "randomize_seed",
+   "safe", "enhance", "private",
]

BASIC_PARAMETERS = {
@@ -61,6 +62,8 @@
    "max_new_tokens": 1024,
    "max_tokens": 4096,
    "seed": 42,
+   "stop": ["stop1", "stop2"],
+   "tools": [],
}

class AbstractProvider(BaseProvider):
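Aside (hypothetical helper, names invented for this sketch): the diff does not show how the parameter whitelist and the example values above are consumed, but the get_parameters(as_json=True) call in g4f/gui/server/api.py suggests a flow along these lines: intersect a provider entry point's keyword arguments with the whitelist and attach an example value where one is defined.

import inspect

# Invented names; the real module-level names are not visible in this diff.
ALLOWED_PARAMETERS = ["seed", "stop", "tools", "safe", "enhance", "private", "web_search"]
EXAMPLE_VALUES = {"seed": 42, "stop": ["stop1", "stop2"], "tools": []}

def example_parameters(entry_point) -> dict:
    """Map the whitelisted keyword arguments of entry_point to example values."""
    names = [name for name in inspect.signature(entry_point).parameters
             if name in ALLOWED_PARAMETERS]
    return {name: EXAMPLE_VALUES.get(name) for name in names}

async def create_async_generator(model, messages, seed: int = None,
                                 stop: list = None, private: bool = False):
    ...

print(example_parameters(create_async_generator))
# {'seed': 42, 'stop': ['stop1', 'stop2'], 'private': None}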