Skip to content

Commit

Permalink
enhance:speedup xinference audio transcription (#3636)
Browse files Browse the repository at this point in the history
  • Loading branch information
leslie2046 authored Apr 23, 2024
1 parent 83caffe commit f76ac8b
Showing 1 changed file with 16 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]

# initialize client
client = Client(
base_url=credentials['server_url']
)

xinference_client = client.get_model(model_uid=credentials['model_uid'])

if not isinstance(xinference_client, RESTfulAudioModelHandle):
raise InvokeBadRequestError(
'please check model type, the model you want to invoke is not a audio model')

audio_file_path = self._get_demo_file_path()

with open(audio_file_path, 'rb') as audio_file:
Expand Down Expand Up @@ -110,17 +124,8 @@ def _speech2text_invoke(
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]

# initialize client
client = Client(
base_url=credentials['server_url']
)

xinference_client = client.get_model(model_uid=credentials['model_uid'])

if not isinstance(xinference_client, RESTfulAudioModelHandle):
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a audio model')

response = xinference_client.transcriptions(
handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={})
response = handle.transcriptions(
audio=file,
language = language,
prompt = prompt,
Expand Down

0 comments on commit f76ac8b

Please sign in to comment.