Skip to content

Improve QA pipeline error handling #8286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 36 additions & 38 deletions src/transformers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import uuid
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable
from contextlib import contextmanager
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -1597,55 +1598,52 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
command-line supplied arguments.
"""

def normalize(self, item):
if isinstance(item, SquadExample):
return item
elif isinstance(item, dict):
for k in ["question", "context"]:
if k not in item:
raise KeyError("You need to provide a dictionary with keys {question:..., context:...}")
elif item[k] is None:
raise ValueError("`{}` cannot be None".format(k))
elif isinstance(item[k], str) and len(item[k]) == 0:
raise ValueError("`{}` cannot be empty".format(k))

return QuestionAnsweringPipeline.create_sample(**item)
raise ValueError("{} argument needs to be of type (SquadExample, dict)".format(item))

def __call__(self, *args, **kwargs):
# Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating
# Detect where the actual inputs are
if args is not None and len(args) > 0:
if len(args) == 1:
kwargs["X"] = args[0]
inputs = args[0]
elif len(args) == 2 and {type(el) for el in args} == {str}:
inputs = [{"question": args[0], "context": args[1]}]
else:
kwargs["X"] = list(args)

inputs = list(args)
# Generic compatibility with sklearn and Keras
# Batched data
if "X" in kwargs or "data" in kwargs:
inputs = kwargs["X"] if "X" in kwargs else kwargs["data"]

if isinstance(inputs, dict):
inputs = [inputs]
else:
# Copy to avoid overriding arguments
inputs = [i for i in inputs]

for i, item in enumerate(inputs):
if isinstance(item, dict):
if any(k not in item for k in ["question", "context"]):
raise KeyError("You need to provide a dictionary with keys {question:..., context:...}")

inputs[i] = QuestionAnsweringPipeline.create_sample(**item)

elif not isinstance(item, SquadExample):
raise ValueError(
"{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)".format(
"X" if "X" in kwargs else "data"
)
)

# Tabular input
elif "X" in kwargs:
inputs = kwargs["X"]
elif "data" in kwargs:
inputs = kwargs["data"]
elif "question" in kwargs and "context" in kwargs:
if isinstance(kwargs["question"], str):
kwargs["question"] = [kwargs["question"]]

if isinstance(kwargs["context"], str):
kwargs["context"] = [kwargs["context"]]

inputs = [
QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs["question"], kwargs["context"])
]
inputs = [{"question": kwargs["question"], "context": kwargs["context"]}]
else:
raise ValueError("Unknown arguments {}".format(kwargs))

if not isinstance(inputs, list):
# Normalize inputs
if isinstance(inputs, dict):
inputs = [inputs]
elif isinstance(inputs, Iterable):
# Copy to avoid overriding arguments
inputs = [i for i in inputs]
else:
raise ValueError("Invalid arguments {}".format(inputs))

for i, item in enumerate(inputs):
inputs[i] = self.normalize(item)

return inputs

Expand Down
118 changes: 115 additions & 3 deletions tests/test_pipelines_question_answering.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

from transformers.pipelines import Pipeline
from transformers.data.processors.squad import SquadExample
from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler

from .test_pipelines_common import CustomInputPipelineCommonMixin

Expand Down Expand Up @@ -43,5 +44,116 @@ def _test_pipeline(self, nlp: Pipeline):
for key in output_keys:
self.assertIn(key, result)
for bad_input in invalid_inputs:
self.assertRaises(Exception, nlp, bad_input)
self.assertRaises(Exception, nlp, invalid_inputs)
self.assertRaises(ValueError, nlp, bad_input)
self.assertRaises(ValueError, nlp, invalid_inputs)

def test_argument_handler(self):
qa = QuestionAnsweringArgumentHandler()

Q = "Where was HuggingFace founded ?"
C = "HuggingFace was founded in Paris"

normalized = qa(Q, C)
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})

normalized = qa(question=Q, context=C)
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})

normalized = qa(question=Q, context=C)
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})

normalized = qa({"question": Q, "context": C})
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})

normalized = qa([{"question": Q, "context": C}])
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})

normalized = qa([{"question": Q, "context": C}, {"question": Q, "context": C}])
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 2)
self.assertEqual({type(el) for el in normalized}, {SquadExample})

normalized = qa(X={"question": Q, "context": C})
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})

normalized = qa(X=[{"question": Q, "context": C}])
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})

normalized = qa(data={"question": Q, "context": C})
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})

def test_argument_handler_error_handling(self):
qa = QuestionAnsweringArgumentHandler()

Q = "Where was HuggingFace founded ?"
C = "HuggingFace was founded in Paris"

with self.assertRaises(KeyError):
qa({"context": C})
with self.assertRaises(KeyError):
qa({"question": Q})
with self.assertRaises(KeyError):
qa([{"context": C}])
with self.assertRaises(ValueError):
qa(None, C)
with self.assertRaises(ValueError):
qa("", C)
with self.assertRaises(ValueError):
qa(Q, None)
with self.assertRaises(ValueError):
qa(Q, "")

with self.assertRaises(ValueError):
qa(question=None, context=C)
with self.assertRaises(ValueError):
qa(question="", context=C)
with self.assertRaises(ValueError):
qa(question=Q, context=None)
with self.assertRaises(ValueError):
qa(question=Q, context="")

with self.assertRaises(ValueError):
qa({"question": None, "context": C})
with self.assertRaises(ValueError):
qa({"question": "", "context": C})
with self.assertRaises(ValueError):
qa({"question": Q, "context": None})
with self.assertRaises(ValueError):
qa({"question": Q, "context": ""})

with self.assertRaises(ValueError):
qa([{"question": Q, "context": C}, {"question": None, "context": C}])
with self.assertRaises(ValueError):
qa([{"question": Q, "context": C}, {"question": "", "context": C}])

with self.assertRaises(ValueError):
qa([{"question": Q, "context": C}, {"question": Q, "context": None}])
with self.assertRaises(ValueError):
qa([{"question": Q, "context": C}, {"question": Q, "context": ""}])

def test_argument_handler_error_handling_odd(self):
qa = QuestionAnsweringArgumentHandler()
with self.assertRaises(ValueError):
qa(None)

with self.assertRaises(ValueError):
qa(Y=None)

with self.assertRaises(ValueError):
qa(1)