Skip to content

Commit

Permalink
Georgedyre/openai upgrade (#2) (#33369)
Browse files Browse the repository at this point in the history
* Keep original question for first in conversation

* Add OpenAI version 1.0 support

* Fix openai client

---------

Co-authored-by: George Dyre <georgedyre@microsoft.com>
  • Loading branch information
gdyre and georgedyre authored Dec 5, 2023
1 parent 2bcc83c commit cb678ce
Showing 1 changed file with 76 additions and 26 deletions.
102 changes: 76 additions & 26 deletions sdk/ai/azure-ai-generative/azure/ai/generative/synthetic/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from enum import Enum
from functools import lru_cache
from typing import Dict, List, Tuple, Any, Union
import openai
from collections import defaultdict
from azure.ai.resources.entities import BaseConnection
from azure.identity import DefaultAzureCredential
Expand All @@ -22,27 +21,62 @@
print("In order to use qa, please install the 'qa_generation' extra of azure-ai-generative")
raise e

try:
    # NOTE(review): pkg_resources is deprecated in modern setuptools; a
    # follow-up could switch to importlib.metadata.version + packaging.version.
    import pkg_resources
    openai_version_str = pkg_resources.get_distribution("openai").version
    openai_version = pkg_resources.parse_version(openai_version_str)
    import openai

    # The retryable exception types moved between openai 0.x and 1.x, so pick
    # the tuple that matches the installed client at import time.
    if openai_version >= pkg_resources.parse_version("1.0.0"):
        _RETRY_ERRORS = (
            openai.APIConnectionError,
            openai.APIError,
            openai.APIStatusError,
        )
    else:
        _RETRY_ERRORS = (
            openai.error.ServiceUnavailableError,
            openai.error.APIError,
            openai.error.RateLimitError,
            openai.error.APIConnectionError,
            openai.error.Timeout,
        )
except ImportError:
    print("In order to use qa, please install the 'qa_generation' extra of azure-ai-generative")
    # Bare re-raise preserves the original traceback (was `raise e`).
    raise

# Directory containing the prompt templates shipped alongside this module.
_TEMPLATES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")
activity_logger = ActivityLogger(__name__)
logger, module_logger = activity_logger.package_logger, activity_logger.module_logger

# Default Azure OpenAI API version when the caller does not supply one.
_DEFAULT_AOAI_VERSION = "2023-07-01-preview"
# Retry budget for transient OpenAI errors (see _completion_with_retries).
_MAX_RETRIES = 7
# NOTE: _RETRY_ERRORS is defined by the version-aware try-block above. The
# unconditional openai<1.0 redefinition that previously lived here has been
# removed: it clobbered the version-aware tuple and raised AttributeError on
# openai>=1.0, where the openai.error module no longer exists.



# NOTE(review): this span is a commit-diff rendering, not clean source — the
# indentation has been stripped and pre-commit (openai<1.0) and post-commit
# (version-aware) lines are interleaved. Comments below flag the residue; the
# code text itself is left byte-identical.
def _completion_with_retries(*args, **kwargs):
# Calls chat completions, retrying up to _MAX_RETRIES times on the transient
# errors listed in _RETRY_ERRORS, then re-raising.
n = 1
while True:
try:
# NOTE(review): stale pre-commit line — superseded by the version branch
# directly below; if both survived a merge the API would be called twice.
response = openai.ChatCompletion.create(*args, **kwargs)
if openai_version >= pkg_resources.parse_version("1.0.0"):
# openai>=1.0 path: build a per-call client from the legacy kwargs.
# NOTE(review): assumes kwargs always contains "api_type" — KeyError
# otherwise; confirm _chat_completion_params always sets it.
if kwargs["api_type"].lower() == "azure":
from openai import AzureOpenAI
client = AzureOpenAI(
azure_endpoint = kwargs["api_base"],
api_key=kwargs["api_key"],
api_version=kwargs["api_version"]
)
response = client.chat.completions.create(messages=kwargs["messages"], model=kwargs["deployment_id"], temperature=kwargs["temperature"], max_tokens=kwargs["max_tokens"])
else:
from openai import OpenAI
client = OpenAI(
api_key=kwargs["api_key"],
)
response = client.chat.completions.create(messages=kwargs["messages"], model=kwargs["model"], temperature=kwargs["temperature"], max_tokens=kwargs["max_tokens"])
# v1 returns a typed object: extract a (content, usage-dict) tuple.
return response.choices[0].message.content, dict(response.usage)
else:
# openai<1.0 legacy module-level API; response is dict-like.
response = openai.ChatCompletion.create(*args, **kwargs)
return response["choices"][0].message.content, response["usage"]
except _RETRY_ERRORS as e:
if n > _MAX_RETRIES:
raise
# NOTE(review): the lines computing the backoff delay `secs` are collapsed
# behind the diff's "Expand" marker and are not visible here.
Expand All @@ -51,14 +85,31 @@ def _completion_with_retries(*args, **kwargs):
time.sleep(secs)
n += 1
continue
# NOTE(review): stale pre-commit line — unreachable after the returns above.
return response


# NOTE(review): diff-rendered fragment — indentation stripped, pre-/post-commit
# lines interleaved; code text left byte-identical, flags added below.
async def _completion_with_retries_async(*args, **kwargs):
# Async variant: retries up to _MAX_RETRIES times on _RETRY_ERRORS, then
# re-raises.
n = 1
while True:
try:
# NOTE(review): stale pre-commit line — superseded by the branch below.
response = await openai.ChatCompletion.acreate(*args, **kwargs)
if openai_version >= pkg_resources.parse_version("1.0.0"):
if kwargs["api_type"].lower() == "azure":
from openai import AsyncAzureOpenAI
client = AsyncAzureOpenAI(
azure_endpoint = kwargs["api_base"],
api_key=kwargs["api_key"],
api_version=kwargs["api_version"]
)
response = await client.chat.completions.create(messages=kwargs["messages"], model=kwargs["deployment_id"], temperature=kwargs["temperature"], max_tokens=kwargs["max_tokens"])
else:
from openai import AsyncOpenAI
client = AsyncOpenAI(
api_key=kwargs["api_key"],
)
response = await client.chat.completions.create(messages=kwargs["messages"], model=kwargs["model"], temperature=kwargs["temperature"], max_tokens=kwargs["max_tokens"])
return response.choices[0].message.content, dict(response.usage)
else:
# BUG(review): this openai<1.0 branch of an *async* function makes a
# blocking synchronous call and never awaits; it should be
# `await openai.ChatCompletion.acreate(*args, **kwargs)` (as the stale
# pre-commit line above shows).
response = openai.ChatCompletion.create(*args, **kwargs)
return response["choices"][0].message.content, response["usage"]
except _RETRY_ERRORS as e:
if n > _MAX_RETRIES:
raise
# NOTE(review): the `secs` backoff computation is hidden behind the
# diff's "Expand" marker.
Expand All @@ -67,7 +118,6 @@ async def _completion_with_retries_async(*args, **kwargs):
await asyncio.sleep(secs)
n += 1
continue
# NOTE(review): stale pre-commit line — unreachable.
return response

class OutputStructure(str, Enum):
"""OutputStructure defines what structure the QAs should be written to file in."""
Expand Down Expand Up @@ -190,15 +240,16 @@ def _merge_token_usage(self, token_usage: Dict, token_usage2: Dict) -> Dict:
return {name: count + token_usage[name] for name, count in token_usage2.items()}

def _modify_conversation_questions(self, questions) -> Tuple[List[str], Dict]:
    """Rewrite follow-up questions for conversational QA.

    Sends the questions through the chat model and returns a tuple of
    (modified questions, token usage dict). Removed the stale pre-upgrade
    diff residue (duplicate `response = ...` call line and the
    `response["choices"]...` / `response["usage"]` accessors) that conflicted
    with the (content, usage) tuple now returned by
    _completion_with_retries.

    Raises:
        AssertionError: if the model returns a different number of questions
            than it was given (parsing error).
    """
    content, usage = _completion_with_retries(
        messages=self._get_messages_for_modify_conversation(questions),
        **self._chat_completion_params,
    )
    modified_questions, _ = self._parse_qa_from_response(content)
    # Keep proper nouns in first question of conversation
    modified_questions[0] = questions[0]
    assert len(modified_questions) == len(questions), self._PARSING_ERR_UNEQUAL_Q_AFTER_MOD
    return modified_questions, usage

@distributed_trace
@monitor_with_activity(logger, "QADataGenerator.Export", ActivityType.INTERNALCALL)
Expand Down Expand Up @@ -266,13 +317,12 @@ def export_to_file(self, output_path: str, qa_type: QAType, results: Union[List,
@monitor_with_activity(logger, "QADataGenerator.Generate", ActivityType.INTERNALCALL)
# Generates question/answer pairs from `text` for the requested QAType,
# accumulating token usage across the extra conversation-rewrite call.
def generate(self, text: str, qa_type: QAType, num_questions: int = None) -> Dict:
self._validate(qa_type, num_questions)
# NOTE(review): stale pre-commit line — superseded by the tuple-unpacking
# call directly below.
response = _completion_with_retries(
content, token_usage = _completion_with_retries(
messages=self._get_messages_for_qa_type(qa_type, text, num_questions),
**self._chat_completion_params,
)
# NOTE(review): stale pre-commit line — `response` no longer exists here.
questions, answers = self._parse_qa_from_response(response["choices"][0].message.content)
questions, answers = self._parse_qa_from_response(content)
assert len(questions) == len(answers), self._PARSING_ERR_UNEQUAL_QA
# NOTE(review): stale pre-commit line — token_usage is already unpacked above.
token_usage = response["usage"]
if qa_type == QAType.CONVERSATION:
questions, token_usage2 = self._modify_conversation_questions(questions)
token_usage = self._merge_token_usage(token_usage, token_usage2)
# NOTE(review): the body of the return dict is collapsed behind the diff's
# "Expand" marker; only its closing brace is visible below.
Expand All @@ -282,27 +332,27 @@ def generate(self, text: str, qa_type: QAType, num_questions: int = None) -> Dic
}

async def _modify_conversation_questions_async(self, questions) -> Tuple[List[str], Dict]:
    """Async variant of _modify_conversation_questions.

    Rewrites follow-up questions for conversational QA via the async chat
    client and returns (modified questions, token usage dict). Stale
    pre-upgrade diff residue (duplicate call line and the
    `response["choices"]...` / `response["usage"]` accessors) removed to
    match the (content, usage) tuple returned by
    _completion_with_retries_async.

    Raises:
        AssertionError: if the model returns a different number of questions
            than it was given (parsing error).
    """
    content, usage = await _completion_with_retries_async(
        messages=self._get_messages_for_modify_conversation(questions),
        **self._chat_completion_params,
    )
    modified_questions, _ = self._parse_qa_from_response(content)
    # Keep proper nouns in first question of conversation
    modified_questions[0] = questions[0]
    assert len(modified_questions) == len(questions), self._PARSING_ERR_UNEQUAL_Q_AFTER_MOD
    return modified_questions, usage

@distributed_trace
@monitor_with_activity(logger, "QADataGenerator.GenerateAsync", ActivityType.INTERNALCALL)
# Async variant of generate(): produces QA pairs and accumulated token usage.
async def generate_async(self, text: str, qa_type: QAType, num_questions: int = None) -> Dict:
self._validate(qa_type, num_questions)
# NOTE(review): stale pre-commit line — superseded by the tuple-unpacking
# call directly below.
response = await _completion_with_retries_async(
content, token_usage = await _completion_with_retries_async(
messages=self._get_messages_for_qa_type(qa_type, text, num_questions),
**self._chat_completion_params,
)
# NOTE(review): stale pre-commit line — `response` no longer exists here.
questions, answers = self._parse_qa_from_response(response["choices"][0].message.content)
questions, answers = self._parse_qa_from_response(content)
assert len(questions) == len(answers), self._PARSING_ERR_UNEQUAL_QA
# NOTE(review): stale pre-commit line — token_usage is already unpacked above.
token_usage = response["usage"]
if qa_type == QAType.CONVERSATION:
questions, token_usage2 = await self._modify_conversation_questions_async(questions)
token_usage = self._merge_token_usage(token_usage, token_usage2)
# NOTE(review): the return statement is collapsed past the diff's
# "Expand Down" marker and is not visible in this view.
Expand Down

0 comments on commit cb678ce

Please sign in to comment.