Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changes in UIETask to remove mutliple schema construction #2170

Merged
merged 3 commits into from
May 16, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions paddlenlp/taskflow/information_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import paddle
from ..datasets import load_dataset
Expand Down Expand Up @@ -121,6 +119,7 @@ class UIETask(Task):

def __init__(self, task, model, schema, **kwargs):
super().__init__(task=task, model=model, **kwargs)
self._schema_tree = None
self.set_schema(schema)
if model not in self.encoding_model_map.keys():
raise ValueError(
Expand All @@ -142,8 +141,7 @@ def __init__(self, task, model, schema, **kwargs):
def set_schema(self, schema):
if isinstance(schema, dict) or isinstance(schema, str):
schema = [schema]
self._schema = schema
self._build_tree(self._schema)
self._schema_tree = self._build_tree(schema)

def _construct_input_spec(self):
"""
Expand Down Expand Up @@ -330,31 +328,42 @@ def _auto_joiner(self, short_results, short_inputs, input_mapping):

def _run_model(self, inputs):
raw_inputs = inputs['text']
schema_tree = self._build_tree(self._schema)
results = self._multi_stage_predict(raw_inputs, schema_tree)
results = self._multi_stage_predict(raw_inputs)
inputs['result'] = results
return inputs

def _multi_stage_predict(self, datas, schema_tree):
def _multi_stage_predict(self, datas):
"""
Traversal the schema tree and do multi-stage prediction.

Args:
datas (list): a list of strings

Returns:
list: a list of predictions, where the list's length
equals to the length of `datas`
"""
results = [{} for i in range(len(datas))]
schema_list = schema_tree.children
results = [{} for _ in range(len(datas))]
# input check to early return
if len(datas) < 1 or self._schema_tree is None:
return results

# copy to stay `self._schema_tree` unchanged
schema_list = self._schema_tree.children[:]
while len(schema_list) > 0:
node = schema_list.pop(0)
examples = []
input_map = {}
cnt = 0
id = 0
idx = 0
if not node.prefix:
for data in datas:
examples.append({
"text": data,
"prompt": dbc2sbc(node.name)
})
input_map[cnt] = [id]
id += 1
input_map[cnt] = [idx]
idx += 1
cnt += 1
else:
for pre, data in zip(node.prefix, datas):
Expand All @@ -366,8 +375,8 @@ def _multi_stage_predict(self, datas, schema_tree):
"text": data,
"prompt": dbc2sbc(p + node.name)
})
input_map[cnt] = [i + id for i in range(len(pre))]
id += len(pre)
input_map[cnt] = [i + idx for i in range(len(pre))]
idx += len(pre)
cnt += 1
if len(examples) == 0:
result_list = []
Expand All @@ -377,13 +386,13 @@ def _multi_stage_predict(self, datas, schema_tree):
if not node.parent_relations:
relations = [[] for i in range(len(datas))]
for k, v in input_map.items():
for id in v:
if len(result_list[id]) == 0:
for idx in v:
if len(result_list[idx]) == 0:
continue
if node.name not in results[k].keys():
results[k][node.name] = result_list[id]
results[k][node.name] = result_list[idx]
else:
results[k][node.name].extend(result_list[id])
results[k][node.name].extend(result_list[idx])
if node.name in results[k].keys():
relations[k].extend(results[k][node.name])
else:
Expand Down Expand Up @@ -415,11 +424,11 @@ def _multi_stage_predict(self, datas, schema_tree):
"relations"][node.name][k])
relations = new_relations

prefix = [[] for i in range(len(datas))]
prefix = [[] for _ in range(len(datas))]
for k, v in input_map.items():
for id in v:
for i in range(len(result_list[id])):
prefix[k].append(result_list[id][i]["text"] + "的")
for idx in v:
for i in range(len(result_list[idx])):
prefix[k].append(result_list[idx][i]["text"] + "的")

for child in node.children:
child.prefix = prefix
Expand Down Expand Up @@ -459,7 +468,8 @@ def _convert_ids_to_results(self, examples, sentence_ids, probs):
results.append(result_list)
return results

def _build_tree(self, schema, name='root'):
@classmethod
def _build_tree(cls, schema, name='root'):
"""
Build the schema tree.
"""
Expand All @@ -477,7 +487,7 @@ def _build_tree(self, schema, name='root'):
raise TypeError(
"Invalid schema, value for each key:value pairs should be list or string"
"but {} received".format(type(v)))
schema_tree.add_child(self._build_tree(child, name=k))
schema_tree.add_child(cls._build_tree(child, name=k))
else:
raise TypeError(
"Invalid schema, element should be string or dict, "
Expand Down