Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
SQuAD data source (#382)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #382

Data source specifically for the SQuAD 2.0 dataset.

Differential Revision: D14072144

fbshipit-source-id: 7e9765d1fa8e52a1ea641c50d7c97f59eacfc8d3
  • Loading branch information
borguz authored and facebook-github-bot committed Mar 12, 2019
1 parent d0d8e88 commit 634406a
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
84 changes: 84 additions & 0 deletions pytext/data/sources/squad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json
from typing import List, Optional

from pytext.data import types
from pytext.data.sources.data_source import DataSource, generator_property


def flatten(fname, ignore_impossible):
if not fname:
return
with open(fname) as file:
dump = json.load(file)
for article in dump["data"]:
for paragraph in article["paragraphs"]:
context = paragraph["context"]
for question in paragraph["qas"]:
label = not question["is_impossible"]
if label or not ignore_impossible:
answers = (
question["answers"] if label else question["plausible_answers"]
)
yield {
"context": context,
"question": question["question"],
"answers": [answer["text"] for answer in answers],
"answer_starts": [int(ans["answer_start"]) for ans in answers],
"label": types.Label(label),
}


class SquadDataSource(DataSource):
"""Download data from https://rajpurkar.github.io/SQuAD-explorer/
Will return tuples of (context, question, answer, answer_start, label, weight)
"""

class Config(DataSource.Config):
train_filename: Optional[str] = "train-v2.0.json"
test_filename: Optional[str] = "dev-v2.0.json"
eval_filename: Optional[str] = "dev-v2.0.json"
ignore_impossible: bool = True

@classmethod
def from_config(cls, config: Config, schema=None):
return cls(
config.train_filename,
config.test_filename,
config.eval_filename,
config.ignore_impossible,
)

def __init__(
self,
train_filename=None,
test_filename=None,
eval_filename=None,
ignore_impossible=Config.ignore_impossible,
):
schema = {
"context": str,
"question": str,
"answers": List[str],
"answer_starts": List[int],
"answer_ends": List[int],
"label": types.Label,
}
super().__init__(schema)
self.train_filename = train_filename
self.test_filename = test_filename
self.eval_filename = eval_filename
self.ignore_impossible = ignore_impossible

@generator_property
def train(self):
return flatten(self.train_filename, self.ignore_impossible)

@generator_property
def test(self):
return flatten(self.test_filename, self.ignore_impossible)

@generator_property
def eval(self):
return flatten(self.eval_filename, self.ignore_impossible)
6 changes: 6 additions & 0 deletions pytext/data/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@ class Label(DataType, str):

class Text(DataType, str):
"""Human language text."""

class Int(DataType, int):
"""Int type."""

class Float(DataType, float):
"""Float type."""

0 comments on commit 634406a

Please sign in to comment.