Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ ASSEMBLYAI_API_KEY="<your key here>"
AWS_API_KEY="<your key here>"
AZURE_API_KEY="<your key here>"
DEEPGRAM_API_KEY="<your key here>"
ELEVATEAI_API_KEY="<your key here>"
GOOGLE_API_KEY="<your key here>"
REVAI_API_KEY="<your key here>"
SPEECHMATICS_API_KEY="<your key here>"
Expand Down
15 changes: 15 additions & 0 deletions src/rtasr/asr/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ class DeepgramOptions(TypedDict, total=False):
version: str


class ElevateAIOptions(TypedDict, total=False):
"""
The options for the ElevateAI transcription.

References from the ElevateAI docs:
https://docs.elevateai.com/transcription-features
"""

type: str
languageTag: str
vertical: str
audioTranscriptionMode: str
includeAiResults: bool


class GoogleOptions(TypedDict, total=False):
"""
The options for the Google transcription.
Expand Down
161 changes: 161 additions & 0 deletions src/rtasr/asr/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
AwsOptions,
AzureOptions,
DeepgramOptions,
ElevateAIOptions,
GoogleOptions,
RevAIOptions,
SpeechmaticsOptions,
Expand All @@ -35,6 +36,9 @@
AzureOutput,
DeepgramOutput,
DeepgramUtterance,
ElevateAIOutput,
ElevateAIRedactionSegment,
ElevateAISentenceSegment,
GoogleOutput,
RevAIElement,
RevAIMonologue,
Expand Down Expand Up @@ -835,6 +839,163 @@ async def result_to_rttm(self, asr_output: DeepgramOutput) -> List[str]:
return rttm_lines


class ElevateAI(ASRProvider):
"""The ASR provider class for ElevateAI."""

def __init__(
self,
api_url: str,
api_key: str,
options: dict,
concurrency_limit: Union[int, None],
) -> None:
super().__init__(api_url, api_key, concurrency_limit)
self.options = ElevateAIOptions(**options)

@property
def output_schema(self) -> ElevateAIOutput:
"""The output format of the ElevateAI ASR provider."""
return ElevateAIOutput

async def get_transcription(
self,
audio_file: Path,
url: HttpUrl,
session: aiohttp.ClientSession,
) -> Tuple[TranscriptionStatus, ElevateAIOutput]:
"""Call the API of the ElevateAI ASR provider."""
headers = {
"Content-Type": "application/json",
"X-API-Token": f"{self.config.api_key.get_secret_value()}",
}

form = aiohttp.FormData()
form.add_field("originalFileName", audio_file.name)
for k, v in self.options.items():
if isinstance(v, (bool, dict, list, tuple)):
serialized_value = json.dumps(v)
else:
serialized_value = str(v)

form.add_field(k, serialized_value)

async with session.post(url=str(url), data=form, headers=headers) as response:
if response.status == 201 or response.status == 200:
content = (await response.text()).strip()
elif response.status == 504:
raise GatewayTimeoutError(response.status)
else:
raise Exception(await response.text())

print(content)
body = json.loads(content)
interaction_id = body.get("interactionIdentifier")

async with aiofiles.open(audio_file, mode="rb") as f:
form = aiohttp.FormData()
form.add_field(
f"{audio_file.name}",
f,
filename=audio_file.name,
content_type="application/octet-stream",
)

async with session.post(
url=f"{url}/{interaction_id}/upload", data=form, headers=headers,
) as response:
if response.status == 201 or response.status == 200:
pass
elif response.status == 504:
raise GatewayTimeoutError(response.status)
else:
raise Exception(await response.text())

headers.pop("Content-Type")
headers["Accept-Encoding"] = "gzip, deflate, br"
while True:
async with session.get(
url=f"{url}/{interaction_id}/status", headers=headers
) as response:
if response.status == 200:
content = (await response.text()).strip()

body = json.loads(content)
if body.get("status") == "processed":
status = TranscriptionStatus.COMPLETED
break
elif body.get("status") == "processingFailed":
status = TranscriptionStatus.FAILED
break
else:
await asyncio.sleep(3)
elif response.status == 504:
await asyncio.sleep(3)
else:
raise Exception(await response.text())

if status == TranscriptionStatus.COMPLETED:
headers["Content-Type"] = "application/json"
async with session.get(
url=f"{url}/{interaction_id}/transcripts/punctuated", headers=headers
) as response:
if response.status == 200:
content = (await response.text()).strip()
else:
raise Exception(await response.text())

body = json.loads(content)
asr_output = ElevateAIOutput.from_json(body)

else:
asr_output = body.get("errorMessage")

return status, asr_output

def _order_results(
self, asr_output: ElevateAIOutput
) -> List[Union[ElevateAIRedactionSegment, ElevateAISentenceSegment]]:
"""Order the results of the ElevateAI ASR provider."""
sentences: List[ElevateAISentenceSegment] = asr_output.sentenceSegments
redactions: List[ElevateAIRedactionSegment] = asr_output.redactionSegments

# We need to sort the sentences and redactions by their start time.
utterances = sorted(
sentences + redactions,
key=lambda x: x.startTimeOffset,
)

return utterances

async def result_to_dialogue(self, asr_output: ElevateAIOutput) -> List[str]:
"""Convert the result to dialogue format for WER."""
utterances = self._order_results(asr_output)

dialogue_lines: List[str] = []
for utterance in utterances:
if isinstance(utterance, ElevateAISentenceSegment):
dialogue_lines.append(utterance.phrase)
elif isinstance(utterance, ElevateAIRedactionSegment):
dialogue_lines.append(utterance.redaction)

return dialogue_lines

async def result_to_rttm(self, asr_output: ElevateAIOutput) -> List[str]:
"""Convert the result to RTTM format for DER."""
utterances = (
asr_output.sentenceSegments
) # We skip redaction that don't have a speaker attribute

rttm_lines: List[str] = []
for utterance in utterances:
start_seconds: float = utterance.startTimeOffset / 1000
end_seconds: float = utterance.endTimeOffset / 1000
speaker: int = utterance.participant

rttm_lines.append(f"{start_seconds} {end_seconds} {speaker}")

return rttm_lines


class Google(ASRProvider):
"""The ASR provider class for Google."""

Expand Down
26 changes: 26 additions & 0 deletions src/rtasr/asr/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,32 @@ class DeepgramOutput(ASROutput):
results: DeepgramResult


class ElevateAIRedactionSegment(BaseModel):
"""ElevateAI redaction segment schema."""

endTimeOffset: int
result: str
score: float
startTimeOffset: int


class ElevateAISentenceSegment(BaseModel):
"""ElevateAI segment schema."""

endTimeOffset: int
participant: str
phrase: str
score: float
startTimeOffset: int


class ElevateAIOutput(ASROutput):
"""ElevateAI output schema."""

redactionSegments: List[ElevateAIRedactionSegment]
sentenceSegments: List[ElevateAISentenceSegment]


class GoogleOutput(ASROutput):
"""Google output schema."""

Expand Down
21 changes: 21 additions & 0 deletions src/rtasr/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,27 @@
},
},
),
(
"elevateai",
{
"url": "https://api.elevateai.com/v1/interactions",
"engine": "ElevateAI",
"output": "ElevateAIOutput",
"speaker_map": "ElevateAISpeakerMap",
"concurrency_limit": 5,
"options": {
"type": "audio",
"languageTag": "en",
"vertical": "default",
"audioTranscriptionMode": "highAccuracy",
"includeAiResults": True,
},
"pricing": {
"value": 0.0030,
"unit": "minute",
},
},
),
(
"google",
{
Expand Down
1 change: 1 addition & 0 deletions src/rtasr/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ProviderNameDisplay(str, Enum):
aws = "AWS"
azure = "Azure"
deepgram = "Deepgram"
elevateai = "ElevateAI"
google = "Google"
revai = "Rev.ai"
speechmatics = "Speechmatics"
Expand Down
46 changes: 46 additions & 0 deletions src/rtasr/speaker_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,52 @@ def from_value(cls, speaker_id: int) -> str:
return cls[_speaker_id].value


# TODO: The current ElevateAI speaker map only supports up to 2 speakers...
# But we'll keep the rest of the participants here for future use.
class ElevateAISpeakerMap(str, Enum):
"""ElevateAI speaker map."""

participantOne = "A"
participantTwo = "B"
participantThree = "C"
participantFour = "D"
participantFive = "E"
participantSix = "F"
participantSeven = "G"
participantEight = "H"
participantNine = "I"
participantTen = "J"
participantEleven = "K"
participantTwelve = "L"
participantThirteen = "M"
participantFourteen = "N"
participantFifteen = "O"
participantSixteen = "P"
participantSeventeen = "Q"
participantEighteen = "R"
participantNineteen = "S"
participantTwenty = "T"
participantTwentyOne = "U"
participantTwentyTwo = "V"
participantTwentyThree = "W"
participantTwentyFour = "X"
participantTwentyFive = "Y"
participantTwentySix = "Z"

@classmethod
def from_value(cls, speaker_id: str) -> str:
"""Get speaker map from a string value."""
if speaker_id not in cls.__members__:
raise ValueError(
f"Speaker name {speaker_id} not found in speaker map.HINT: Speaker"
" names are in the format `participantX` where `X` is a number, "
"between `One` and `TwentySix`. For example, `participantOne` or"
"`participantTwentySix` are valid speaker names."
)
else:
return cls[speaker_id].value


class GoogleSpeakerMap(str, Enum):
"""Google speaker map."""

Expand Down
21 changes: 21 additions & 0 deletions tests/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,27 @@ def test_providers_deepgram(self) -> None:
"unit": "minute",
}

def test_providers_elevateai(self) -> None:
"""Test ElevateAI provider."""
assert (
PROVIDERS["elevateai"]["url"] == "https://api.elevateai.com/v1/interactions"
)
assert PROVIDERS["elevateai"]["engine"] == "ElevateAI"
assert PROVIDERS["elevateai"]["output"] == "ElevateAIOutput"
assert PROVIDERS["elevateai"]["speaker_map"] == "ElevateAISpeakerMap"
assert PROVIDERS["elevateai"]["concurrency_limit"] == 5
assert PROVIDERS["elevateai"]["options"] == {
"type": "audio",
"languageTag": "en",
"vertical": "default",
"audioTranscriptionMode": "highAccuracy",
"includeAiResults": True,
}
assert PROVIDERS["elevateai"]["pricing"] == {
"value": 0.0030,
"unit": "minute",
}

def test_providers_google(self) -> None:
"""Test Google provider."""
assert PROVIDERS["google"]["url"] == ""
Expand Down
Loading