Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
tutorial for new DataSource (#420)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #420

This is a tutorial to create a DataSource (using the new data handler)

Differential Revision: D14554266

fbshipit-source-id: c2204c9617b0e6ef76c77abaa21339bd73f7b349
  • Loading branch information
Titousensei authored and facebook-github-bot committed Mar 29, 2019
1 parent b21c407 commit d315123
Show file tree
Hide file tree
Showing 3 changed files with 477 additions and 0 deletions.
164 changes: 164 additions & 0 deletions demo/datasource/source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import os
from pytext.data.utils import UNK
from random import Random
from typing import Dict, List, Optional, Type

from pytext.data.sources.data_source import RootDataSource


def load_vocab(file_path):
    """
    Load a vocabulary file into a lookup dictionary.

    Every line of the file is one entry. The key is the zero-based line
    index converted to a string ("0", "1", ...), and the value is the content
    of that line with leading/trailing whitespace removed.

    Args:
        file_path: path of the text file to read, one entry per line.

    Returns:
        A dict mapping ``str(line_index)`` to the stripped line.
    """
    # Explicit encoding so the result does not depend on the locale default.
    with open(file_path, "r", encoding="utf-8") as file_contents:
        return {str(idx): word.strip() for idx, word in enumerate(file_contents)}


def reader(file_path, vocab):
    """
    Lazily read a file of whitespace-separated vocab keys, one row per line.

    For each line, the first and last tokens are dropped, every remaining
    token is looked up in `vocab` (falling back to UNK when missing), and the
    resulting words are joined with single spaces.

    Args:
        file_path: path of the file to read.
        vocab: dict mapping a token (as it appears in the file) to a word,
            e.g. the result of `load_vocab`.

    Yields:
        One string per line of the file.
    """
    # This is a module-level function: use the parameters directly (there is
    # no `self` here), and do not shadow the function name with the handle.
    with open(file_path, "r") as rows:
        for line in rows:
            yield " ".join(
                vocab.get(token, UNK)
                # ATIS every row starts/ends with BOS/EOS: remove them
                for token in line.split()[1:-1]
            )


class AtisIntentDataSource(RootDataSource):
    """
    DataSource which loads queries and intent from the ATIS dataset.
    The simple usage is to provide the path to the unzipped atis directory,
    and it will use the default filenames and parameters.
    ATIS dataset has the following characteristics:
    - words are stored in a dict file.
    - content files contain only indices of words.
    - there's no eval set: we'll take random rows from the train set.
    - all queries start with BOS (Beginning Of Sentence) and end with EOS
    (End Of Sentence), which we'll remove.
    """

    class Config(RootDataSource.Config):
        # Directory containing the ATIS files listed below
        path: str = "."
        # Names of the two output columns: [query column, intent column]
        field_names: List[str] = ["text", "label"]
        # Fraction of the train rows set aside for the eval set
        validation_split: float = 0.25
        # Seed of the pseudo-random train/eval split
        random_seed: int = 12345
        # Filenames can be overridden if necessary
        intent_filename: str = "atis.dict.intent.csv"
        vocab_filename: str = "atis.dict.vocab.csv"
        test_queries_filename: str = "atis.test.query.csv"
        test_intent_filename: str = "atis.test.intent.csv"
        train_queries_filename: str = "atis.train.query.csv"
        train_intent_filename: str = "atis.train.intent.csv"

    # Config mimics the constructor
    # This will be the default in future pytext.
    @classmethod
    def from_config(cls, config: Config, schema: Dict[str, Type]):
        """Build the data source from a Config, forwarding every config field
        as a constructor keyword argument."""
        return cls(schema=schema, **config._asdict())

    def __init__(
        self,
        path="my_directory",
        # NOTE: this default list is only read (indexed), never mutated.
        field_names=["text", "label"],
        validation_split=0.25,
        random_seed=12345,
        intent_filename="atis.dict.intent.csv",
        vocab_filename="atis.dict.vocab.csv",
        test_queries_filename="atis.test.query.csv",
        test_intent_filename="atis.test.intent.csv",
        train_queries_filename="atis.train.query.csv",
        train_intent_filename="atis.train.intent.csv",
        **kwargs,
    ):
        """
        Args:
            path: directory containing the ATIS files.
            field_names: exactly two names, used as the keys of each yielded
                row: first the query column, then the intent column.
            validation_split: fraction of the train rows used as eval set.
            random_seed: seed of the train/eval split.
            *_filename: names of the individual files inside `path`.
            **kwargs: forwarded unchanged to RootDataSource.__init__
                (e.g. `schema`).
        """
        super().__init__(**kwargs)

        assert (
            len(field_names or []) == 2
        ), "AtisIntentDataSource only handles 2 field_names: {}".format(field_names)

        self.random_seed = random_seed
        self.validation_split = validation_split

        # Load the vocab dicts in memory: they map the indices found in the
        # content files back to the actual words / intent names.
        self.words = load_vocab(os.path.join(path, vocab_filename))
        self.intents = load_vocab(os.path.join(path, intent_filename))

        self.query_field = field_names[0]
        self.intent_field = field_names[1]

        self.test_queries_filepath = os.path.join(path, test_queries_filename)
        self.test_intent_filepath = os.path.join(path, test_intent_filename)
        self.train_queries_filepath = os.path.join(path, train_queries_filename)
        self.train_intent_filepath = os.path.join(path, train_intent_filename)

    def _selector(self, select_eval):
        """
        Return a zero-argument function deciding, row after row, whether the
        current row belongs to the requested set.

        Every selector starts a fresh pseudo-random sequence from
        `random_seed`, so the sequence of draws is always the same. The
        `select_eval` parameter flips the decision, which guarantees that the
        training set and eval set are exact complements, provided the
        function is called exactly once per row, in row order.
        """
        rng = Random(self.random_seed)

        def fn():
            # draw >= validation_split -> train row; XOR flips it for eval
            return select_eval ^ (rng.random() >= self.validation_split)

        return fn

    def _iter_rows(self, query_reader, intent_reader, select_fn=lambda: True):
        """
        Pair up queries and intents and yield one dict per selected row,
        keyed by the two configured field names. `select_fn` is called once
        per row; the default keeps every row.
        """
        for query_str, intent_str in zip(query_reader, intent_reader):
            if select_fn():
                yield {self.query_field: query_str, self.intent_field: intent_str}

    def raw_train_data_generator(self):
        """Iterate over the rows of the train files not held out for eval."""
        return iter(
            self._iter_rows(
                query_reader=reader(self.train_queries_filepath, self.words),
                intent_reader=reader(self.train_intent_filepath, self.intents),
                select_fn=self._selector(select_eval=False),
            )
        )

    def raw_eval_data_generator(self):
        """Iterate over the rows of the train files held out for eval."""
        return iter(
            self._iter_rows(
                query_reader=reader(self.train_queries_filepath, self.words),
                intent_reader=reader(self.train_intent_filepath, self.intents),
                select_fn=self._selector(select_eval=True),
            )
        )

    def raw_test_data_generator(self):
        """Iterate over all the rows of the test files."""
        return iter(
            self._iter_rows(
                query_reader=reader(self.test_queries_filepath, self.words),
                intent_reader=reader(self.test_intent_filepath, self.intents),
            )
        )


# Strings need no conversion: register an identity function as the loader
# that AtisIntentDataSource uses for fields declared as `str`.
@AtisIntentDataSource.register_type(str)
def load_string(s: str) -> str:
    """Return the raw string value unchanged."""
    return s


if __name__ == "__main__":
    # Demo: print every row of the train, eval and test sets of the ATIS
    # directory given as first command-line argument.
    import sys

    if len(sys.argv) < 2:
        # Exit with a readable message instead of an IndexError traceback.
        sys.exit("usage: {} <path to unzipped atis directory>".format(sys.argv[0]))

    src = AtisIntentDataSource(sys.argv[1], field_names=["query", "intent"], schema={})
    for row in src.raw_train_data_generator():
        print("TRAIN", row)
    for row in src.raw_eval_data_generator():
        print("EVAL", row)
    for row in src.raw_test_data_generator():
        print("TEST", row)
Loading

0 comments on commit d315123

Please sign in to comment.