Skip to content

chore: sync code base with OSS repo #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion assemblyai/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.33.0"
__version__ = "0.34.0"
38 changes: 19 additions & 19 deletions assemblyai/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def create_transcript(
by_alias=True,
),
)
if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.TranscriptError(
f"failed to transcribe url {request.audio_url}: {_get_error_message(response)}"
)
Expand All @@ -57,7 +57,7 @@ def get_transcript(
f"{ENDPOINT_TRANSCRIPT}/{transcript_id}",
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.TranscriptError(
f"failed to retrieve transcript {transcript_id}: {_get_error_message(response)}",
)
Expand All @@ -73,7 +73,7 @@ def delete_transcript(
f"{ENDPOINT_TRANSCRIPT}/{transcript_id}",
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.TranscriptError(
f"failed to delete transcript {transcript_id}: {_get_error_message(response)}",
)
Expand All @@ -100,7 +100,7 @@ def upload_file(
content=audio_file,
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.TranscriptError(
f"Failed to upload audio file: {_get_error_message(response)}"
)
Expand All @@ -125,7 +125,7 @@ def export_subtitles_srt(
params=params,
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.TranscriptError(
f"failed to export SRT for transcript {transcript_id}: {_get_error_message(response)}"
)
Expand All @@ -150,7 +150,7 @@ def export_subtitles_vtt(
params=params,
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.TranscriptError(
f"failed to export VTT for transcript {transcript_id}: {_get_error_message(response)}"
)
Expand All @@ -172,7 +172,7 @@ def word_search(
),
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.TranscriptError(
f"failed to search words in transcript {transcript_id}: {_get_error_message(response)}"
)
Expand Down Expand Up @@ -223,7 +223,7 @@ def get_sentences(
f"{ENDPOINT_TRANSCRIPT}/{transcript_id}/sentences",
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.TranscriptError(
f"failed to retrieve sentences for transcript {transcript_id}: {_get_error_message(response)}"
)
Expand All @@ -239,7 +239,7 @@ def get_paragraphs(
f"{ENDPOINT_TRANSCRIPT}/{transcript_id}/paragraphs",
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.TranscriptError(
f"failed to retrieve paragraphs for transcript {transcript_id}: {_get_error_message(response)}"
)
Expand All @@ -262,7 +262,7 @@ def list_transcripts(
),
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.AssemblyAIError(
f"failed to retrieve transcripts: {_get_error_message(response)}"
)
Expand All @@ -283,7 +283,7 @@ def lemur_question(
timeout=http_timeout,
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.LemurError(
f"failed to call Lemur questions: {_get_error_message(response)}"
)
Expand All @@ -304,7 +304,7 @@ def lemur_summarize(
timeout=http_timeout,
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.LemurError(
f"failed to call Lemur summary: {_get_error_message(response)}"
)
Expand All @@ -325,7 +325,7 @@ def lemur_action_items(
timeout=http_timeout,
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.LemurError(
f"failed to call Lemur action items: {_get_error_message(response)}"
)
Expand All @@ -346,7 +346,7 @@ def lemur_task(
timeout=http_timeout,
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.LemurError(
f"failed to call Lemur task: {_get_error_message(response)}"
)
Expand All @@ -358,13 +358,13 @@ def lemur_purge_request_data(
client: httpx.Client,
request: types.LemurPurgeRequest,
http_timeout: Optional[float],
) -> types.LemurPurgeRequest:
) -> types.LemurPurgeResponse:
response = client.delete(
f"{ENDPOINT_LEMUR_BASE}/{request.request_id}",
timeout=http_timeout,
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.LemurError(
f"Failed to purge LeMUR request data for provided request ID: {request.request_id}. Error: {_get_error_message(response)}"
)
Expand All @@ -374,7 +374,7 @@ def lemur_purge_request_data(

def lemur_get_response_data(
client: httpx.Client,
request_id: int,
request_id: str,
http_timeout: Optional[float],
) -> Union[
types.LemurStringResponse,
Expand All @@ -385,7 +385,7 @@ def lemur_get_response_data(
timeout=http_timeout,
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.LemurError(
f"Failed to get LeMUR response data for provided request ID: {request_id}. Error: {_get_error_message(response)}"
)
Expand All @@ -409,7 +409,7 @@ def create_temporary_token(
timeout=http_timeout,
)

if response.status_code != httpx.codes.ok:
if response.status_code != httpx.codes.OK:
raise types.AssemblyAIError(
f"Failed to create temporary token: {_get_error_message(response)}"
)
Expand Down
10 changes: 9 additions & 1 deletion assemblyai/lemur.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(
self,
*,
client: _client.Client,
sources: List[types.LemurSource],
sources: Optional[List[types.LemurSource]],
) -> None:
self._client = client

Expand Down Expand Up @@ -103,6 +103,7 @@ def action_items(
def task(
self,
prompt: str,
context: Optional[Union[str, Dict[str, Any]]],
final_model: Optional[types.LemurModel],
max_output_size: Optional[int],
timeout: Optional[float],
Expand All @@ -114,6 +115,7 @@ def task(
request=types.LemurTaskRequest(
sources=self._sources,
prompt=prompt,
context=context,
final_model=final_model,
max_output_size=max_output_size,
temperature=temperature,
Expand Down Expand Up @@ -438,6 +440,7 @@ def action_items_async(
def task(
self,
prompt: str,
context: Optional[Union[str, Dict[str, Any]]] = None,
final_model: Optional[types.LemurModel] = None,
max_output_size: Optional[int] = None,
timeout: Optional[float] = None,
Expand All @@ -451,6 +454,7 @@ def task(

Args:
prompt: The prompt to use for this task.
context: An optional context on the transcript.
final_model: The model that is used for the final prompt after compression is performed.
max_output_size: Max output size in tokens
timeout: The timeout in seconds to wait for the task.
Expand All @@ -462,6 +466,7 @@ def task(

return self._impl.task(
prompt=prompt,
context=context,
final_model=final_model,
max_output_size=max_output_size,
timeout=timeout,
Expand All @@ -472,6 +477,7 @@ def task(
def task_async(
self,
prompt: str,
context: Optional[Union[str, Dict[str, Any]]] = None,
final_model: Optional[types.LemurModel] = None,
max_output_size: Optional[int] = None,
timeout: Optional[float] = None,
Expand All @@ -485,6 +491,7 @@ def task_async(

Args:
prompt: The prompt to use for this task.
context: An optional context on the transcript.
final_model: The model that is used for the final prompt after compression is performed.
max_output_size: Max output size in tokens
timeout: The timeout in seconds to wait for the task.
Expand All @@ -497,6 +504,7 @@ def task_async(
return self._executor.submit(
self._impl.task,
prompt=prompt,
context=context,
final_model=final_model,
max_output_size=max_output_size,
timeout=timeout,
Expand Down
39 changes: 18 additions & 21 deletions assemblyai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ class RawTranscriptionConfig(BaseModel):
iab_categories: Optional[bool]
"Enable Topic Detection."

custom_spelling: Optional[List[Dict[str, List[str]]]]
custom_spelling: Optional[List[Dict[str, Union[str, List[str]]]]]
"Customize how words are spelled and formatted using to and from values"

disfluencies: Optional[bool]
Expand Down Expand Up @@ -649,10 +649,11 @@ def __init__(
speech_threshold: Reject audio files that contain less than this fraction of speech. Valid values are in the range [0,1] inclusive.
raw_transcription_config: Create the config from a `RawTranscriptionConfig`
"""
self._raw_transcription_config = raw_transcription_config

if raw_transcription_config is None:
self._raw_transcription_config = RawTranscriptionConfig()
self._raw_transcription_config = (
raw_transcription_config
if raw_transcription_config is not None
else RawTranscriptionConfig()
)

# explicit configurations have higher priority if `raw_transcription_config` has been passed as well
self.language_code = language_code
Expand Down Expand Up @@ -914,17 +915,21 @@ def iab_categories(self, enable: Optional[bool]) -> None:
self._raw_transcription_config.iab_categories = enable

@property
def custom_spelling(self) -> Optional[Dict[str, List[str]]]:
def custom_spelling(self) -> Optional[Dict[str, Union[str, List[str]]]]:
"Returns the current set custom spellings."

if self._raw_transcription_config.custom_spelling is None:
return None

custom_spellings = {}
for custom_spelling in self._raw_transcription_config.custom_spelling:
custom_spellings[custom_spelling["from"]] = custom_spelling["to"]
_from = custom_spelling["from"]
if isinstance(_from, str):
custom_spellings[_from] = custom_spelling["to"]
else:
raise ValueError("`from` argument must be a string!")

return custom_spellings
return custom_spellings if custom_spelling else None

@property
def disfluencies(self) -> Optional[bool]:
Expand All @@ -938,8 +943,6 @@ def disfluencies(self, enable: Optional[bool]) -> None:

self._raw_transcription_config.disfluencies = enable

return self

@property
def sentiment_analysis(self) -> Optional[bool]:
"Returns the status of the Sentiment Analysis feature."
Expand All @@ -953,7 +956,7 @@ def sentiment_analysis(self, enable: Optional[bool]) -> None:
self._raw_transcription_config.sentiment_analysis = enable

@property
def auto_chapters(self) -> bool:
def auto_chapters(self) -> Optional[bool]:
"Returns the status of the Auto Chapters feature."

return self._raw_transcription_config.auto_chapters
Expand All @@ -971,7 +974,7 @@ def auto_chapters(self, enable: Optional[bool]) -> None:
self._raw_transcription_config.auto_chapters = enable

@property
def entity_detection(self) -> bool:
def entity_detection(self) -> Optional[bool]:
"Returns whether Entity Detection feature is enabled or not."

return self._raw_transcription_config.entity_detection
Expand Down Expand Up @@ -1076,7 +1079,7 @@ def set_casing_and_formatting(

def set_speaker_diarization(
self,
enable: bool = True,
enable: Optional[bool] = True,
speakers_expected: Optional[int] = None,
) -> Self:
"""
Expand Down Expand Up @@ -1261,7 +1264,7 @@ def set_custom_spelling(

def set_summarize(
self,
enable: bool = True,
enable: Optional[bool] = True,
model: Optional[SummarizationModel] = None,
type: Optional[SummarizationType] = None,
) -> Self:
Expand Down Expand Up @@ -1866,13 +1869,6 @@ def source(self) -> Sourcable:
"""
return self._source

@property
def context(self) -> Optional[Union[str, Dict[str, Any]]]:
"""
An optional context on the source (can be a string or an arbitrary dictionary)
"""
return self._context

@property
def type(self) -> LemurSourceType:
"""
Expand Down Expand Up @@ -2068,6 +2064,7 @@ class LemurStringResponse(BaseLemurResponse):


class LemurTaskRequest(BaseLemurRequest):
context: Optional[Union[str, Dict[str, Any]]]
prompt: str


Expand Down
3 changes: 0 additions & 3 deletions tests/unit/test_auto_chapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ def test_auto_chapters_fails_without_punctuation(httpx_mock: HTTPXMock):
# Check that the error was raised before any requests were made
assert len(httpx_mock.get_requests()) == 0

# Inform httpx_mock that it's okay we didn't make any requests
httpx_mock.reset(False)


def test_auto_chapters_disabled_by_default(httpx_mock: HTTPXMock):
"""
Expand Down
3 changes: 0 additions & 3 deletions tests/unit/test_content_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,3 @@ def test_content_safety_with_invalid_confidence_threshold(

# Check that the error was raised before any requests were made
assert len(httpx_mock.get_requests()) == 0

# Inform httpx_mock that it's okay we didn't make any requests
httpx_mock.reset(False)
5 changes: 4 additions & 1 deletion tests/unit/test_lemur.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,9 @@ def test_lemur_task_succeeds_transcript(httpx_mock: HTTPXMock):
lemur = aai.Lemur(
sources=[aai.LemurSource(transcript)],
)
result = lemur.task(prompt="Create action items of the meeting")
result = lemur.task(
prompt="Create action items of the meeting", context="An important meeting"
)

# check the response
assert isinstance(result, aai.LemurTaskResponse)
Expand Down Expand Up @@ -559,6 +561,7 @@ def test_lemur_task_succeeds(final_model, httpx_mock: HTTPXMock):
result = lemur.task(
final_model=final_model,
prompt="Create action items of the meeting",
context="An important meeting",
input_text="Test test",
)

Expand Down
Loading
Loading