Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[textanalytics] add custom named entities bespoke method #24995

Merged
merged 5 commits into from
Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

- Added `begin_recognize_custom_entities` client method to recognize custom named entities in documents.

### Breaking Changes

- Removed the Extractive Text Summarization feature and related models: `ExtractSummaryAction`, `ExtractSummaryResult`, and `SummarySentence`. To access this beta feature, install the `5.2.0b4` version of the client library.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
AnalyzeHealthcareEntitiesAction,
)

from ._lro import AnalyzeHealthcareEntitiesLROPoller, AnalyzeActionsLROPoller
from ._lro import AnalyzeHealthcareEntitiesLROPoller, AnalyzeActionsLROPoller, TextAnalyticsLROPoller

__all__ = [
"TextAnalyticsApiVersion",
Expand Down Expand Up @@ -114,6 +114,7 @@
"ClassifyDocumentResult",
"ClassificationCategory",
"AnalyzeHealthcareEntitiesAction",
"TextAnalyticsLROPoller",
kristapratico marked this conversation as resolved.
Show resolved Hide resolved
]

__version__ = VERSION
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import functools
import json
import datetime
from typing import Any, Optional
from typing import Any, Optional, MutableMapping
from urllib.parse import urlencode
from azure.core.polling._poller import PollingReturnType
from azure.core.exceptions import HttpResponseError
Expand Down Expand Up @@ -228,6 +228,9 @@ def from_continuation_token( # type: ignore
continuation_token: str,
**kwargs: Any
) -> "AnalyzeHealthcareEntitiesLROPoller": # type: ignore
"""
:meta private:
catalinaperalta marked this conversation as resolved.
Show resolved Hide resolved
"""
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
continuation_token, **kwargs
)
Expand Down Expand Up @@ -457,6 +460,50 @@ def from_continuation_token( # type: ignore
continuation_token: str,
**kwargs: Any
) -> "AnalyzeActionsLROPoller": # type: ignore
"""
:meta private:
"""
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
continuation_token, **kwargs
)
polling_method._lro_algorithms = [ # pylint: disable=protected-access
TextAnalyticsOperationResourcePolling(
show_stats=initial_response.context.options["show_stats"]
)
]
return cls(
client,
initial_response,
functools.partial(deserialization_callback, initial_response),
polling_method
)


class TextAnalyticsLROPoller(LROPoller[PollingReturnType]):
def polling_method(self) -> AnalyzeActionsLROPollingMethod:
    """Return the polling method associated to this poller.

    :return: The object stored on this poller as ``_polling_method``.
    """
    return self._polling_method  # type: ignore

@property
def details(self) -> MutableMapping[str, Any]:
    """Return metadata about the operation tracked by this poller.

    Every value is read from the attribute of the same name on the object
    returned by ``self.polling_method()``.

    :return: A mapping with the keys ``id``, ``created_on``, ``expires_on``,
        ``display_name`` and ``last_modified_on``.
    """
    # Resolve the polling method once instead of once per key.
    method = self.polling_method()
    return {
        "id": method.id,
        "created_on": method.created_on,
        "expires_on": method.expires_on,
        "display_name": method.display_name,
        "last_modified_on": method.last_modified_on,
    }

@classmethod
def from_continuation_token( # type: ignore
cls,
polling_method: AnalyzeActionsLROPollingMethod,
continuation_token: str,
**kwargs: Any
) -> "TextAnalyticsLROPoller": # type: ignore
"""
:meta private:
"""
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
continuation_token, **kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def order_lro_results(doc_id_order, combined):
def prepare_result(func):
def choose_wrapper(*args, **kwargs):
def wrapper(
response, obj, response_headers, ordering_function
kristapratico marked this conversation as resolved.
Show resolved Hide resolved
): # pylint: disable=unused-argument
response, obj, _, ordering_function
):
if hasattr(obj, "results"):
obj = obj.results # language API compat

Expand Down Expand Up @@ -280,7 +280,7 @@ def classify_document_result(


def healthcare_extract_page_data(
doc_id_order, obj, response_headers, health_job_state
doc_id_order, obj, health_job_state
): # pylint: disable=unused-argument
return (
health_job_state.next_link,
Expand All @@ -289,7 +289,7 @@ def healthcare_extract_page_data(
health_job_state.results
if hasattr(health_job_state, "results")
else health_job_state.tasks.items[0].results,
response_headers,
{},
lro=True
),
)
Expand Down Expand Up @@ -382,7 +382,7 @@ def get_ordered_errors(tasks_obj, task_name, doc_id_order):
raise ValueError("Unexpected response from service - no errors for missing action results.")


def _get_doc_results(task, doc_id_order, response_headers, returned_tasks_object):
def _get_doc_results(task, doc_id_order, returned_tasks_object):
returned_tasks = returned_tasks_object.tasks
current_task_type, task_name = task
deserialization_callback = _get_deserialization_callback_from_task_type(
Expand All @@ -401,18 +401,25 @@ def _get_doc_results(task, doc_id_order, response_headers, returned_tasks_object
if response_task_to_deserialize.results is None:
return get_ordered_errors(returned_tasks_object, task_name, doc_id_order)
return deserialization_callback(
doc_id_order, response_task_to_deserialize.results, response_headers, lro=True
doc_id_order, response_task_to_deserialize.results, {}, lro=True
kristapratico marked this conversation as resolved.
Show resolved Hide resolved
)


def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state):
def get_iter_items(doc_id_order, task_order, bespoke, analyze_job_state):
iter_items = defaultdict(list) # map doc id to action results
returned_tasks_object = analyze_job_state

if bespoke:
return _get_doc_results(
task_order[0],
doc_id_order,
returned_tasks_object,
)

for task in task_order:
results = _get_doc_results(
task,
doc_id_order,
response_headers,
returned_tasks_object,
)
for result in results:
Expand All @@ -422,11 +429,11 @@ def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state


def analyze_extract_page_data(
doc_id_order, task_order, response_headers, analyze_job_state
doc_id_order, task_order, bespoke, analyze_job_state
):
# return next link, list of
iter_items = get_iter_items(
doc_id_order, task_order, response_headers, analyze_job_state
doc_id_order, task_order, bespoke, analyze_job_state
)
return analyze_job_state.next_link, iter_items

Expand Down Expand Up @@ -456,14 +463,14 @@ def lro_get_next_page(


def healthcare_paged_result(
doc_id_order, health_status_callback, _, obj, response_headers, show_stats=False
): # pylint: disable=unused-argument
doc_id_order, health_status_callback, _, obj, show_stats=False
):
return ItemPaged(
functools.partial(
lro_get_next_page, health_status_callback, obj, show_stats=show_stats
),
functools.partial(
healthcare_extract_page_data, doc_id_order, obj, response_headers
healthcare_extract_page_data, doc_id_order, obj
),
)

Expand All @@ -474,14 +481,38 @@ def analyze_paged_result(
analyze_status_callback,
_,
obj,
response_headers,
show_stats=False,
): # pylint: disable=unused-argument
bespoke=False
):
return ItemPaged(
functools.partial(
lro_get_next_page, analyze_status_callback, obj, show_stats=show_stats
),
functools.partial(
analyze_extract_page_data, doc_id_order, task_order, response_headers
analyze_extract_page_data, doc_id_order, task_order, bespoke
),
)


def _get_result_from_continuation_token(
    client, continuation_token, poller_type, polling_method, callback, bespoke=False
):
    """Rebuild a poller from a continuation token and attach result deserialization.

    :param client: Forwarded as ``client`` to
        ``poller_type.from_continuation_token``.
    :param continuation_token: Forwarded as ``continuation_token`` to
        ``poller_type.from_continuation_token``.
    :param poller_type: Object exposing ``from_continuation_token``; it is
        called with keyword arguments only.
    :param polling_method: Forwarded as ``polling_method`` to
        ``poller_type.from_continuation_token``.
    :param callback: Invoked by the deserialization callback as
        ``callback(pipeline_response, None, doc_id_order, task_id_order=...,
        show_stats=..., bespoke=...)``.
    :param bool bespoke: Passed through unchanged to ``callback``.
    :return: Whatever ``poller_type.from_continuation_token`` returns.
    """

    def result_callback(initial_response, pipeline_response):
        # The options recorded on the initial response carry the ordering
        # information needed to deserialize the final result.
        options = initial_response.context.options
        doc_id_order = options["doc_id_order"]
        show_stats = options["show_stats"]
        # "task_id_order" is optional: a missing key yields None.
        task_id_order = options.get("task_id_order")
        return callback(
            pipeline_response,
            None,
            doc_id_order,
            task_id_order=task_id_order,
            show_stats=show_stats,
            bespoke=bespoke,
        )

    return poller_type.from_continuation_token(
        polling_method=polling_method,
        client=client,
        deserialization_callback=result_callback,
        continuation_token=continuation_token,
    )
Loading