Skip to content

Commit

Permalink
[textanalytics] add custom named entities bespoke method (Azure#24995)
Browse files Browse the repository at this point in the history
* initial work

* docs,samples,linting

* expose TA poller

* doc fix + add poller metadata tests

* add missing recordings
  • Loading branch information
kristapratico authored and jeremydvoss committed Jul 21, 2022
1 parent 9b77c48 commit fa206cd
Show file tree
Hide file tree
Showing 31 changed files with 2,898 additions and 285 deletions.
2 changes: 2 additions & 0 deletions sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

- Added `begin_recognize_custom_entities` client method to recognize custom named entities in documents.

### Breaking Changes

- Removed the Extractive Text Summarization feature and related models: `ExtractSummaryAction`, `ExtractSummaryResult`, and `SummarySentence`. To access this beta feature, install the `5.2.0b4` version of the client library.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
AnalyzeHealthcareEntitiesAction,
)

from ._lro import AnalyzeHealthcareEntitiesLROPoller, AnalyzeActionsLROPoller
from ._lro import AnalyzeHealthcareEntitiesLROPoller, AnalyzeActionsLROPoller, TextAnalyticsLROPoller

__all__ = [
"TextAnalyticsApiVersion",
Expand Down Expand Up @@ -114,6 +114,7 @@
"ClassifyDocumentResult",
"ClassificationCategory",
"AnalyzeHealthcareEntitiesAction",
"TextAnalyticsLROPoller",
]

__version__ = VERSION
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import functools
import json
import datetime
from typing import Any, Optional
from typing import Any, Optional, MutableMapping
from urllib.parse import urlencode
from azure.core.polling._poller import PollingReturnType
from azure.core.exceptions import HttpResponseError
Expand Down Expand Up @@ -228,6 +228,9 @@ def from_continuation_token( # type: ignore
continuation_token: str,
**kwargs: Any
) -> "AnalyzeHealthcareEntitiesLROPoller": # type: ignore
"""
:meta private:
"""
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
continuation_token, **kwargs
)
Expand Down Expand Up @@ -457,6 +460,50 @@ def from_continuation_token( # type: ignore
continuation_token: str,
**kwargs: Any
) -> "AnalyzeActionsLROPoller": # type: ignore
"""
:meta private:
"""
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
continuation_token, **kwargs
)
polling_method._lro_algorithms = [ # pylint: disable=protected-access
TextAnalyticsOperationResourcePolling(
show_stats=initial_response.context.options["show_stats"]
)
]
return cls(
client,
initial_response,
functools.partial(deserialization_callback, initial_response),
polling_method
)


class TextAnalyticsLROPoller(LROPoller[PollingReturnType]):
def polling_method(self) -> AnalyzeActionsLROPollingMethod:
"""Return the polling method associated to this poller."""
return self._polling_method # type: ignore

@property
def details(self) -> MutableMapping[str, Any]:
return {
"id": self.polling_method().id,
"created_on": self.polling_method().created_on,
"expires_on": self.polling_method().expires_on,
"display_name": self.polling_method().display_name,
"last_modified_on": self.polling_method().last_modified_on,
}

@classmethod
def from_continuation_token( # type: ignore
cls,
polling_method: AnalyzeActionsLROPollingMethod,
continuation_token: str,
**kwargs: Any
) -> "TextAnalyticsLROPoller": # type: ignore
"""
:meta private:
"""
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
continuation_token, **kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def order_lro_results(doc_id_order, combined):
def prepare_result(func):
def choose_wrapper(*args, **kwargs):
def wrapper(
response, obj, response_headers, ordering_function
): # pylint: disable=unused-argument
response, obj, _, ordering_function
):
if hasattr(obj, "results"):
obj = obj.results # language API compat

Expand Down Expand Up @@ -280,7 +280,7 @@ def classify_document_result(


def healthcare_extract_page_data(
doc_id_order, obj, response_headers, health_job_state
doc_id_order, obj, health_job_state
): # pylint: disable=unused-argument
return (
health_job_state.next_link,
Expand All @@ -289,7 +289,7 @@ def healthcare_extract_page_data(
health_job_state.results
if hasattr(health_job_state, "results")
else health_job_state.tasks.items[0].results,
response_headers,
{},
lro=True
),
)
Expand Down Expand Up @@ -382,7 +382,7 @@ def get_ordered_errors(tasks_obj, task_name, doc_id_order):
raise ValueError("Unexpected response from service - no errors for missing action results.")


def _get_doc_results(task, doc_id_order, response_headers, returned_tasks_object):
def _get_doc_results(task, doc_id_order, returned_tasks_object):
returned_tasks = returned_tasks_object.tasks
current_task_type, task_name = task
deserialization_callback = _get_deserialization_callback_from_task_type(
Expand All @@ -401,18 +401,25 @@ def _get_doc_results(task, doc_id_order, response_headers, returned_tasks_object
if response_task_to_deserialize.results is None:
return get_ordered_errors(returned_tasks_object, task_name, doc_id_order)
return deserialization_callback(
doc_id_order, response_task_to_deserialize.results, response_headers, lro=True
doc_id_order, response_task_to_deserialize.results, {}, lro=True
)


def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state):
def get_iter_items(doc_id_order, task_order, bespoke, analyze_job_state):
iter_items = defaultdict(list) # map doc id to action results
returned_tasks_object = analyze_job_state

if bespoke:
return _get_doc_results(
task_order[0],
doc_id_order,
returned_tasks_object,
)

for task in task_order:
results = _get_doc_results(
task,
doc_id_order,
response_headers,
returned_tasks_object,
)
for result in results:
Expand All @@ -422,11 +429,11 @@ def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state


def analyze_extract_page_data(
    doc_id_order, task_order, bespoke, analyze_job_state
):
    """Extract one page of analyze-actions results.

    Returns a (next_link, iterable-of-results) pair suitable for ItemPaged's
    page-extraction callback.
    """
    return (
        analyze_job_state.next_link,
        get_iter_items(doc_id_order, task_order, bespoke, analyze_job_state),
    )

Expand Down Expand Up @@ -456,14 +463,14 @@ def lro_get_next_page(


def healthcare_paged_result(
    doc_id_order, health_status_callback, _, obj, show_stats=False
):
    """Wrap a healthcare LRO response in an ItemPaged over its result pages."""
    # Bind the polling/continuation callback and the per-page extractor up
    # front so ItemPaged only ever sees zero/one-argument callables.
    get_next = functools.partial(
        lro_get_next_page, health_status_callback, obj, show_stats=show_stats
    )
    extract_data = functools.partial(
        healthcare_extract_page_data, doc_id_order, obj
    )
    return ItemPaged(get_next, extract_data)

def analyze_paged_result(
    doc_id_order,
    task_order,
    analyze_status_callback,
    _,
    obj,
    show_stats=False,
    bespoke=False
):
    """Wrap an analyze-actions LRO response in an ItemPaged over its pages."""
    # Pre-bind both callbacks; ItemPaged drives them to walk next_link pages
    # and to deserialize each page's task results.
    page_getter = functools.partial(
        lro_get_next_page, analyze_status_callback, obj, show_stats=show_stats
    )
    page_extractor = functools.partial(
        analyze_extract_page_data, doc_id_order, task_order, bespoke
    )
    return ItemPaged(page_getter, page_extractor)


def _get_result_from_continuation_token(
client, continuation_token, poller_type, polling_method, callback, bespoke=False
):
def result_callback(initial_response, pipeline_response):
doc_id_order = initial_response.context.options["doc_id_order"]
show_stats = initial_response.context.options["show_stats"]
task_id_order = initial_response.context.options.get("task_id_order")
return callback(
pipeline_response,
None,
doc_id_order,
task_id_order=task_id_order,
show_stats=show_stats,
bespoke=bespoke
)

return poller_type.from_continuation_token(
polling_method=polling_method,
client=client,
deserialization_callback=result_callback,
continuation_token=continuation_token
)
Loading

0 comments on commit fa206cd

Please sign in to comment.