Skip to content

Commit

Permalink
Add AzureOpenAiAgent (#24058)
Browse files Browse the repository at this point in the history
* Add AzureOpenAiAgent

* quality

* Update src/transformers/tools/agents.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

---------

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
  • Loading branch information
sgugger and LysandreJik authored Jun 7, 2023
1 parent 5eb3d3c commit c3572e6
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 2 deletions.
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/agent.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ We provide three types of agents: [`HfAgent`] uses inference endpoints for opens

[[autodoc]] OpenAiAgent

### AzureOpenAiAgent

[[autodoc]] AzureOpenAiAgent

### Agent

[[autodoc]] Agent
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@
],
"tools": [
"Agent",
"AzureOpenAiAgent",
"HfAgent",
"LocalAgent",
"OpenAiAgent",
Expand Down Expand Up @@ -4410,6 +4411,7 @@
# Tools
from .tools import (
Agent,
AzureOpenAiAgent,
HfAgent,
LocalAgent,
OpenAiAgent,
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


_import_structure = {
"agents": ["Agent", "HfAgent", "LocalAgent", "OpenAiAgent"],
"agents": ["Agent", "AzureOpenAiAgent", "HfAgent", "LocalAgent", "OpenAiAgent"],
"base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
}

Expand All @@ -46,7 +46,7 @@
_import_structure["translation"] = ["TranslationTool"]

if TYPE_CHECKING:
from .agents import Agent, HfAgent, LocalAgent, OpenAiAgent
from .agents import Agent, AzureOpenAiAgent, HfAgent, LocalAgent, OpenAiAgent
from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool

try:
Expand Down
126 changes: 126 additions & 0 deletions src/transformers/tools/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,132 @@ def _completion_generate(self, prompts, stop):
return [answer["text"] for answer in result["choices"]]


class AzureOpenAiAgent(Agent):
"""
Agent that uses Azure OpenAI to generate code. See the [official
documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/) to learn how to deploy an openAI
model on Azure
<Tip warning={true}>
The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
`"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a next version.
</Tip>
Args:
deployment_id (`str`):
The name of the deployed Azure openAI model to use.
api_key (`str`, *optional*):
The API key to use. If unset, will look for the environment variable `"AZURE_OPENAI_API_KEY"`.
resource_name (`str`, *optional*):
The name of your Azure OpenAI Resource. If unset, will look for the environment variable
`"AZURE_OPENAI_RESOURCE_NAME"`.
api_version (`str`, *optional*, default to `"2022-12-01"`):
The API version to use for this agent.
is_chat_mode (`bool`, *optional*):
Whether you are using a completion model or a chat model (see note above, chat models won't be as
efficient). Will default to `gpt` being in the `deployment_id` or not.
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`chat_prompt_template.txt` in this repo in this case.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`run_prompt_template.txt` in this repo in this case.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.
Example:
```py
from transformers import AzureOpenAiAgent
agent = AzureAiAgent(deployment_id="Davinci-003", api_key=xxx, resource_name=yyy)
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
```
"""

def __init__(
self,
deployment_id,
api_key=None,
resource_name=None,
api_version="2022-12-01",
is_chat_model=None,
chat_prompt_template=None,
run_prompt_template=None,
additional_tools=None,
):
if not is_openai_available():
raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")

self.deployment_id = deployment_id
openai.api_type = "azure"
if api_key is None:
api_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
if api_key is None:
raise ValueError(
"You need an Azure openAI key to use `AzureOpenAIAgent`. If you have one, set it in your env with "
"`os.environ['AZURE_OPENAI_API_KEY'] = xxx."
)
else:
openai.api_key = api_key
if resource_name is None:
resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME", None)
if resource_name is None:
raise ValueError(
"You need a resource_name to use `AzureOpenAIAgent`. If you have one, set it in your env with "
"`os.environ['AZURE_OPENAI_RESOURCE_NAME'] = xxx."
)
else:
openai.api_base = f"https://{resource_name}.openai.azure.com"
openai.api_version = api_version

if is_chat_model is None:
is_chat_model = "gpt" in deployment_id.lower()
self.is_chat_model = is_chat_model

super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)

def generate_many(self, prompts, stop):
if self.is_chat_model:
return [self._chat_generate(prompt, stop) for prompt in prompts]
else:
return self._completion_generate(prompts, stop)

def generate_one(self, prompt, stop):
if self.is_chat_model:
return self._chat_generate(prompt, stop)
else:
return self._completion_generate([prompt], stop)[0]

def _chat_generate(self, prompt, stop):
result = openai.ChatCompletion.create(
engine=self.deployment_id,
messages=[{"role": "user", "content": prompt}],
temperature=0,
stop=stop,
)
return result["choices"][0]["message"]["content"]

def _completion_generate(self, prompts, stop):
result = openai.Completion.create(
engine=self.deployment_id,
prompt=prompts,
temperature=0,
stop=stop,
max_tokens=200,
)
return [answer["text"] for answer in result["choices"]]


class HfAgent(Agent):
"""
Agent that uses an inference endpoint to generate code.
Expand Down

0 comments on commit c3572e6

Please sign in to comment.