Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
SQuAD data source
Browse files Browse the repository at this point in the history
Summary: Data source specifically for the SQuAD 2.0 dataset.

Differential Revision: D14072144

fbshipit-source-id: 912f24e45a225e5c02f0c9650422e45677eaae97
  • Loading branch information
borguz authored and facebook-github-bot committed Mar 12, 2019
1 parent d0d8e88 commit 3142223
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
81 changes: 81 additions & 0 deletions pytext/data/sources/squad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json
from typing import List, Optional

from pytext.data import types
from pytext.data.sources.data_source import (
DataSource,
SafeFileWrapper,
generator_property,
)


class SquadDataSource(DataSource):
"""Download data from https://rajpurkar.github.io/SQuAD-explorer/
Will return tuples of (context, question, answer, answer_start, label, weight)
"""

class Config(DataSource.Config):
train_filename: Optional[str] = "train-v2.0.json"
test_filename: Optional[str] = "dev-v2.0.json"
eval_filename: Optional[str] = "dev-v2.0.json"

@classmethod
def from_config(cls, config: Config, schema=None):
args = config._asdict()
train_filename = args.pop("train_filename")
test_filename = args.pop("test_filename")
eval_filename = args.pop("eval_filename")
train_file = SafeFileWrapper(train_filename) if train_filename else None
test_file = SafeFileWrapper(test_filename) if test_filename else None
eval_file = SafeFileWrapper(eval_filename) if eval_filename else None
return cls(train_file, test_file, eval_file, **args)

def __init__(self, train_file=None, test_file=None, eval_file=None, **kwargs):
schema = {
"context": str,
"question": str,
"answers": List[str],
"answer_starts": List[int],
"answer_ends": List[int],
"label": types.Label,
}
super().__init__(schema, **kwargs)

def flatten(file):
dump = json.load(file)
for article in dump["data"]:
for paragraph in article["paragraphs"]:
context = paragraph["context"]
for question in paragraph["qas"]:
label = not question["is_impossible"]
answers = (
question["answers"]
if label
else question["plausible_answers"]
)
yield {
"context": context,
"question": question["question"],
"answers": [answer["text"] for answer in answers],
"answer_starts": [int(ans["answer_start"]) for ans in answers],
"answer_ends": [int(ans["answer_end"]) for ans in answers],
"label": types.Label(label),
}

self._train_iter = flatten(train_file) if train_file else []
self._test_iter = flatten(test_file) if test_file else []
self._eval_iter = flatten(eval_file) if eval_file else []

@generator_property
def train(self):
return self._train_iter

@generator_property
def test(self):
return self._test_iter

@generator_property
def eval(self):
return self._eval_iter
6 changes: 6 additions & 0 deletions pytext/data/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@ class Label(DataType, str):

class Text(DataType, str):
"""Human language text."""

class Int(DataType, int):
"""Int type."""

class Float(DataType, float):
"""Float type."""

0 comments on commit 3142223

Please sign in to comment.