Skip to content

Commit c3572e6

Browse files
sgugger and LysandreJik
authored
Add AzureOpenAiAgent (#24058)
* Add AzureOpenAiAgent * quality * Update src/transformers/tools/agents.py Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr> --------- Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
1 parent 5eb3d3c commit c3572e6

File tree

4 files changed

+134
-2
lines changed

4 files changed

+134
-2
lines changed

docs/source/en/main_classes/agent.mdx

+4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ We provide three types of agents: [`HfAgent`] uses inference endpoints for opens
3838

3939
[[autodoc]] OpenAiAgent
4040

41+
### AzureOpenAiAgent
42+
43+
[[autodoc]] AzureOpenAiAgent
44+
4145
### Agent
4246

4347
[[autodoc]] Agent

src/transformers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,7 @@
619619
],
620620
"tools": [
621621
"Agent",
622+
"AzureOpenAiAgent",
622623
"HfAgent",
623624
"LocalAgent",
624625
"OpenAiAgent",
@@ -4410,6 +4411,7 @@
44104411
# Tools
44114412
from .tools import (
44124413
Agent,
4414+
AzureOpenAiAgent,
44134415
HfAgent,
44144416
LocalAgent,
44154417
OpenAiAgent,

src/transformers/tools/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
_import_structure = {
27-
"agents": ["Agent", "HfAgent", "LocalAgent", "OpenAiAgent"],
27+
"agents": ["Agent", "AzureOpenAiAgent", "HfAgent", "LocalAgent", "OpenAiAgent"],
2828
"base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
2929
}
3030

@@ -46,7 +46,7 @@
4646
_import_structure["translation"] = ["TranslationTool"]
4747

4848
if TYPE_CHECKING:
49-
from .agents import Agent, HfAgent, LocalAgent, OpenAiAgent
49+
from .agents import Agent, AzureOpenAiAgent, HfAgent, LocalAgent, OpenAiAgent
5050
from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
5151

5252
try:

src/transformers/tools/agents.py

+126
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,132 @@ def _completion_generate(self, prompts, stop):
453453
return [answer["text"] for answer in result["choices"]]
454454

455455

456+
class AzureOpenAiAgent(Agent):
    """
    Agent that uses Azure OpenAI to generate code. See the [official
    documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/) to learn how to deploy an openAI
    model on Azure

    <Tip warning={true}>

    The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
    `"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a next version.

    </Tip>

    Args:
        deployment_id (`str`):
            The name of the deployed Azure openAI model to use.
        api_key (`str`, *optional*):
            The API key to use. If unset, will look for the environment variable `"AZURE_OPENAI_API_KEY"`.
        resource_name (`str`, *optional*):
            The name of your Azure OpenAI Resource. If unset, will look for the environment variable
            `"AZURE_OPENAI_RESOURCE_NAME"`.
        api_version (`str`, *optional*, defaults to `"2022-12-01"`):
            The API version to use for this agent.
        is_chat_model (`bool`, *optional*):
            Whether you are using a completion model or a chat model (see note above, chat models won't be as
            efficient). Will default to `gpt` being in the `deployment_id` or not.
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `chat_prompt_template.txt` in this repo in this case.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `run_prompt_template.txt` in this repo in this case.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    from transformers import AzureOpenAiAgent

    agent = AzureOpenAiAgent(deployment_id="Davinci-003", api_key=xxx, resource_name=yyy)
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    """

    def __init__(
        self,
        deployment_id,
        api_key=None,
        resource_name=None,
        api_version="2022-12-01",
        is_chat_model=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
    ):
        if not is_openai_available():
            raise ImportError("Using `AzureOpenAiAgent` requires `openai`: `pip install openai`.")

        self.deployment_id = deployment_id
        # The `openai` client is configured module-wide for Azure endpoints.
        openai.api_type = "azure"
        if api_key is None:
            api_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
        if api_key is None:
            raise ValueError(
                "You need an Azure openAI key to use `AzureOpenAIAgent`. If you have one, set it in your env with "
                "`os.environ['AZURE_OPENAI_API_KEY'] = xxx`."
            )
        else:
            openai.api_key = api_key
        if resource_name is None:
            resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME", None)
        if resource_name is None:
            raise ValueError(
                "You need a resource_name to use `AzureOpenAIAgent`. If you have one, set it in your env with "
                "`os.environ['AZURE_OPENAI_RESOURCE_NAME'] = xxx`."
            )
        else:
            # Azure routes requests to a per-resource endpoint rather than api.openai.com.
            openai.api_base = f"https://{resource_name}.openai.azure.com"
        openai.api_version = api_version

        # Heuristic: Azure chat deployments are typically named after GPT models.
        if is_chat_model is None:
            is_chat_model = "gpt" in deployment_id.lower()
        self.is_chat_model = is_chat_model

        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    def generate_many(self, prompts, stop):
        """Generate one completion per prompt, dispatching on chat vs. completion mode."""
        if self.is_chat_model:
            # The chat endpoint takes one message list per call, so loop per prompt.
            return [self._chat_generate(prompt, stop) for prompt in prompts]
        else:
            # The completion endpoint accepts a batch of prompts in one call.
            return self._completion_generate(prompts, stop)

    def generate_one(self, prompt, stop):
        """Generate a single completion for `prompt`, stopping at any string in `stop`."""
        if self.is_chat_model:
            return self._chat_generate(prompt, stop)
        else:
            return self._completion_generate([prompt], stop)[0]

    def _chat_generate(self, prompt, stop):
        """Call the Azure chat-completion endpoint with a single user message."""
        result = openai.ChatCompletion.create(
            engine=self.deployment_id,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            stop=stop,
        )
        return result["choices"][0]["message"]["content"]

    def _completion_generate(self, prompts, stop):
        """Call the Azure completion endpoint on a batch of prompts."""
        result = openai.Completion.create(
            engine=self.deployment_id,
            prompt=prompts,
            temperature=0,
            stop=stop,
            max_tokens=200,
        )
        return [answer["text"] for answer in result["choices"]]
580+
581+
456582
class HfAgent(Agent):
457583
"""
458584
Agent that uses an inference endpoint to generate code.

0 commit comments

Comments
 (0)