@@ -453,6 +453,132 @@ def _completion_generate(self, prompts, stop):
453
453
return [answer ["text" ] for answer in result ["choices" ]]
454
454
455
455
456
+ class AzureOpenAiAgent (Agent ):
457
+ """
458
+ Agent that uses Azure OpenAI to generate code. See the [official
459
+ documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/) to learn how to deploy an openAI
460
+ model on Azure
461
+
462
+ <Tip warning={true}>
463
+
464
+ The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
465
+ `"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a next version.
466
+
467
+ </Tip>
468
+
469
+ Args:
470
+ deployment_id (`str`):
471
+ The name of the deployed Azure openAI model to use.
472
+ api_key (`str`, *optional*):
473
+ The API key to use. If unset, will look for the environment variable `"AZURE_OPENAI_API_KEY"`.
474
+ resource_name (`str`, *optional*):
475
+ The name of your Azure OpenAI Resource. If unset, will look for the environment variable
476
+ `"AZURE_OPENAI_RESOURCE_NAME"`.
477
+ api_version (`str`, *optional*, default to `"2022-12-01"`):
478
+ The API version to use for this agent.
479
+ is_chat_mode (`bool`, *optional*):
480
+ Whether you are using a completion model or a chat model (see note above, chat models won't be as
481
+ efficient). Will default to `gpt` being in the `deployment_id` or not.
482
+ chat_prompt_template (`str`, *optional*):
483
+ Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
484
+ actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
485
+ `chat_prompt_template.txt` in this repo in this case.
486
+ run_prompt_template (`str`, *optional*):
487
+ Pass along your own prompt if you want to override the default template for the `run` method. Can be the
488
+ actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
489
+ `run_prompt_template.txt` in this repo in this case.
490
+ additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
491
+ Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
492
+ one of the default tools, that default tool will be overridden.
493
+
494
+ Example:
495
+
496
+ ```py
497
+ from transformers import AzureOpenAiAgent
498
+
499
+ agent = AzureAiAgent(deployment_id="Davinci-003", api_key=xxx, resource_name=yyy)
500
+ agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
501
+ ```
502
+ """
503
+
504
+ def __init__ (
505
+ self ,
506
+ deployment_id ,
507
+ api_key = None ,
508
+ resource_name = None ,
509
+ api_version = "2022-12-01" ,
510
+ is_chat_model = None ,
511
+ chat_prompt_template = None ,
512
+ run_prompt_template = None ,
513
+ additional_tools = None ,
514
+ ):
515
+ if not is_openai_available ():
516
+ raise ImportError ("Using `OpenAiAgent` requires `openai`: `pip install openai`." )
517
+
518
+ self .deployment_id = deployment_id
519
+ openai .api_type = "azure"
520
+ if api_key is None :
521
+ api_key = os .environ .get ("AZURE_OPENAI_API_KEY" , None )
522
+ if api_key is None :
523
+ raise ValueError (
524
+ "You need an Azure openAI key to use `AzureOpenAIAgent`. If you have one, set it in your env with "
525
+ "`os.environ['AZURE_OPENAI_API_KEY'] = xxx."
526
+ )
527
+ else :
528
+ openai .api_key = api_key
529
+ if resource_name is None :
530
+ resource_name = os .environ .get ("AZURE_OPENAI_RESOURCE_NAME" , None )
531
+ if resource_name is None :
532
+ raise ValueError (
533
+ "You need a resource_name to use `AzureOpenAIAgent`. If you have one, set it in your env with "
534
+ "`os.environ['AZURE_OPENAI_RESOURCE_NAME'] = xxx."
535
+ )
536
+ else :
537
+ openai .api_base = f"https://{ resource_name } .openai.azure.com"
538
+ openai .api_version = api_version
539
+
540
+ if is_chat_model is None :
541
+ is_chat_model = "gpt" in deployment_id .lower ()
542
+ self .is_chat_model = is_chat_model
543
+
544
+ super ().__init__ (
545
+ chat_prompt_template = chat_prompt_template ,
546
+ run_prompt_template = run_prompt_template ,
547
+ additional_tools = additional_tools ,
548
+ )
549
+
550
+ def generate_many (self , prompts , stop ):
551
+ if self .is_chat_model :
552
+ return [self ._chat_generate (prompt , stop ) for prompt in prompts ]
553
+ else :
554
+ return self ._completion_generate (prompts , stop )
555
+
556
+ def generate_one (self , prompt , stop ):
557
+ if self .is_chat_model :
558
+ return self ._chat_generate (prompt , stop )
559
+ else :
560
+ return self ._completion_generate ([prompt ], stop )[0 ]
561
+
562
+ def _chat_generate (self , prompt , stop ):
563
+ result = openai .ChatCompletion .create (
564
+ engine = self .deployment_id ,
565
+ messages = [{"role" : "user" , "content" : prompt }],
566
+ temperature = 0 ,
567
+ stop = stop ,
568
+ )
569
+ return result ["choices" ][0 ]["message" ]["content" ]
570
+
571
+ def _completion_generate (self , prompts , stop ):
572
+ result = openai .Completion .create (
573
+ engine = self .deployment_id ,
574
+ prompt = prompts ,
575
+ temperature = 0 ,
576
+ stop = stop ,
577
+ max_tokens = 200 ,
578
+ )
579
+ return [answer ["text" ] for answer in result ["choices" ]]
580
+
581
+
456
582
class HfAgent (Agent ):
457
583
"""
458
584
Agent that uses an inference endpoint to generate code.
0 commit comments