Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provider Updates and Fixes #2570

Merged
merged 11 commits into from
Jan 15, 2025
Prev Previous commit
Next Next commit
Fix for why web_search = True didn't work
  • Loading branch information
kqlio67 committed Jan 15, 2025
commit 023f96f713b8e94b90e1984985b2caa09e3f34ac
20 changes: 11 additions & 9 deletions g4f/tools/run_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,18 @@ def iter_run_tools(
messages: Messages,
provider: Optional[str] = None,
tool_calls: Optional[list] = None,
web_search: bool = False,
**kwargs
) -> AsyncIterator:
# If web_search is True, enable safe search directly
if web_search:
try:
messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"]))
except Exception as e:
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
# Enable provider native web search
kwargs["web_search"] = True

if tool_calls is not None:
for tool in tool_calls:
if tool.get("type") == "function":
Expand All @@ -77,14 +87,6 @@ def iter_run_tools(
raise_search_exceptions=True,
**tool["function"]["arguments"]
)
elif tool.get("function", {}).get("name") == "safe_search_tool":
tool["function"]["arguments"] = validate_arguments(tool["function"])
try:
messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], **tool["function"]["arguments"]))
except Exception as e:
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
# Enable provider native web search
kwargs["web_search"] = True
elif tool.get("function", {}).get("name") == "continue_tool":
if provider not in ("OpenaiAccount", "HuggingFace"):
last_line = messages[-1]["content"].strip().splitlines()[-1]
Expand All @@ -107,4 +109,4 @@ def on_bucket(match):
if has_bucket and isinstance(messages[-1]["content"], str):
messages[-1]["content"] += BUCKET_INSTRUCTIONS

return iter_callback(model=model, messages=messages, provider=provider, **kwargs)
return iter_callback(model=model, messages=messages, provider=provider, **kwargs)
Loading